Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

execution: avoid decimal overflow and check valid #34399

Merged
merged 23 commits into from
Jun 17, 2022
Merged
Show file tree
Hide file tree
Changes from 18 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 13 additions & 10 deletions expression/aggregation/agg_to_pb.go
Original file line number Diff line number Diff line change
Expand Up @@ -99,46 +99,49 @@ func (desc *baseFuncDesc) GetTiPBExpr(tryWindowDesc bool) (tp tipb.ExprType) {
}

// AggFuncToPBExpr converts aggregate function to pb.
func AggFuncToPBExpr(sctx sessionctx.Context, client kv.Client, aggFunc *AggFuncDesc) *tipb.Expr {
func AggFuncToPBExpr(sctx sessionctx.Context, client kv.Client, aggFunc *AggFuncDesc, storeType kv.StoreType) (*tipb.Expr, error) {
pc := expression.NewPBConverter(client, sctx.GetSessionVars().StmtCtx)
tp := aggFunc.GetTiPBExpr(false)
if !client.IsRequestTypeSupported(kv.ReqTypeSelect, int64(tp)) {
return nil
return nil, errors.New("select request is not supported by client")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should not return an error here, it'll be a compatibility-breaker.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These errors are not reported to the client, because these functions are called in order to check the functions can be pushed down or not. if not, the functions are not pushed down without errors but with some warnings.

}

children := make([]*tipb.Expr, 0, len(aggFunc.Args))
for _, arg := range aggFunc.Args {
pbArg := pc.ExprToPB(arg)
if pbArg == nil {
return nil
return nil, errors.New(aggFunc.String() + " can't be converted to PB.")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ditto

}
children = append(children, pbArg)
}
retType, err := expression.ToPBFieldTypeWithCheck(aggFunc.RetTp, storeType)
if err != nil {
return nil, errors.Trace(err)
}

if tp == tipb.ExprType_GroupConcat {
orderBy := make([]*tipb.ByItem, 0, len(aggFunc.OrderByItems))
sc := sctx.GetSessionVars().StmtCtx
for _, arg := range aggFunc.OrderByItems {
pbArg := expression.SortByItemToPB(sc, client, arg.Expr, arg.Desc)
if pbArg == nil {
return nil
return nil, errors.New(aggFunc.String() + " can't be converted to PB.")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ditto

}
orderBy = append(orderBy, pbArg)
}
// encode GroupConcatMaxLen
GCMaxLen, err := variable.GetSessionOrGlobalSystemVar(sctx.GetSessionVars(), variable.GroupConcatMaxLen)
if err != nil {
sc.AppendWarning(errors.Errorf("Error happened when buildGroupConcat: no system variable named '%s'", variable.GroupConcatMaxLen))
return nil
return nil, errors.Errorf("Error happened when buildGroupConcat: no system variable named '%s'", variable.GroupConcatMaxLen)
}
maxLen, err := strconv.ParseUint(GCMaxLen, 10, 64)
// Should never happen
if err != nil {
sc.AppendWarning(errors.Errorf("Error happened when buildGroupConcat: %s", err.Error()))
return nil
return nil, errors.Errorf("Error happened when buildGroupConcat: %s", err.Error())
}
return &tipb.Expr{Tp: tp, Val: codec.EncodeUint(nil, maxLen), Children: children, FieldType: expression.ToPBFieldType(aggFunc.RetTp), HasDistinct: aggFunc.HasDistinct, OrderBy: orderBy, AggFuncMode: AggFunctionModeToPB(aggFunc.Mode)}
return &tipb.Expr{Tp: tp, Val: codec.EncodeUint(nil, maxLen), Children: children, FieldType: retType, HasDistinct: aggFunc.HasDistinct, OrderBy: orderBy, AggFuncMode: AggFunctionModeToPB(aggFunc.Mode)}, nil
}
return &tipb.Expr{Tp: tp, Children: children, FieldType: expression.ToPBFieldType(aggFunc.RetTp), HasDistinct: aggFunc.HasDistinct, AggFuncMode: AggFunctionModeToPB(aggFunc.Mode)}
return &tipb.Expr{Tp: tp, Children: children, FieldType: retType, HasDistinct: aggFunc.HasDistinct, AggFuncMode: AggFunctionModeToPB(aggFunc.Mode)}, nil
}

// AggFunctionModeToPB converts aggregate function mode to PB.
Expand Down
4 changes: 3 additions & 1 deletion expression/aggregation/agg_to_pb_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ import (
"testing"

"github.com/pingcap/tidb/expression"
"github.com/pingcap/tidb/kv"
"github.com/pingcap/tidb/parser/ast"
"github.com/pingcap/tidb/parser/mysql"
"github.com/pingcap/tidb/types"
Expand Down Expand Up @@ -65,7 +66,8 @@ func TestAggFunc2Pb(t *testing.T) {
aggFunc, err := NewAggFuncDesc(ctx, funcName, args, hasDistinct)
require.NoError(t, err)
aggFunc.RetTp = funcTypes[i]
pbExpr := AggFuncToPBExpr(ctx, client, aggFunc)
pbExpr, err := AggFuncToPBExpr(ctx, client, aggFunc, kv.UnSpecified)
require.NoError(t, err)
js, err := json.Marshal(pbExpr)
require.NoError(t, err)
require.Equal(t, fmt.Sprintf(jsons[i], hasDistinct), string(js))
Expand Down
27 changes: 6 additions & 21 deletions expression/aggregation/base_func.go
Original file line number Diff line number Diff line change
Expand Up @@ -185,21 +185,14 @@ func (a *baseFuncDesc) typeInfer4Sum(ctx sessionctx.Context) {
switch a.Args[0].GetType().GetType() {
case mysql.TypeTiny, mysql.TypeShort, mysql.TypeInt24, mysql.TypeLong, mysql.TypeLonglong, mysql.TypeYear:
a.RetTp = types.NewFieldType(mysql.TypeNewDecimal)
a.RetTp.SetFlen(mathutil.Min(a.Args[0].GetType().GetFlen()+21, mysql.MaxDecimalWidth))
a.RetTp.SetFlenUnderLimit(a.Args[0].GetType().GetFlen() + 21)
a.RetTp.SetDecimal(0)
if a.Args[0].GetType().GetFlen() < 0 || a.RetTp.GetFlen() > mysql.MaxDecimalWidth {
if a.Args[0].GetType().GetFlen() < 0 {
a.RetTp.SetFlen(mysql.MaxDecimalWidth)
}
case mysql.TypeNewDecimal:
a.RetTp = types.NewFieldType(mysql.TypeNewDecimal)
a.RetTp.SetFlen(a.Args[0].GetType().GetFlen() + 22)
a.RetTp.SetDecimal(a.Args[0].GetType().GetDecimal())
if a.Args[0].GetType().GetFlen() < 0 || a.RetTp.GetFlen() > mysql.MaxDecimalWidth {
a.RetTp.SetFlen(mysql.MaxDecimalWidth)
}
if a.RetTp.GetDecimal() < 0 || a.RetTp.GetDecimal() > mysql.MaxDecimalScale {
a.RetTp.SetDecimal(mysql.MaxDecimalScale)
}
a.RetTp.UpdateFlenAndDecimalUnderLimit(a.Args[0].GetType(), 0, 22)
case mysql.TypeDouble, mysql.TypeFloat:
a.RetTp = types.NewFieldType(mysql.TypeDouble)
a.RetTp.SetFlen(mysql.MaxRealWidth)
Expand All @@ -226,20 +219,12 @@ func (a *baseFuncDesc) typeInfer4Avg(ctx sessionctx.Context) {
switch a.Args[0].GetType().GetType() {
case mysql.TypeTiny, mysql.TypeShort, mysql.TypeInt24, mysql.TypeLong, mysql.TypeLonglong:
a.RetTp = types.NewFieldType(mysql.TypeNewDecimal)
a.RetTp.SetDecimal(types.DivFracIncr)
a.RetTp.SetDecimalUnderLimit(types.DivFracIncr)
flen, _ := mysql.GetDefaultFieldLengthAndDecimal(a.Args[0].GetType().GetType())
a.RetTp.SetFlen(flen + types.DivFracIncr)
a.RetTp.SetFlenUnderLimit(flen + types.DivFracIncr)
case mysql.TypeYear, mysql.TypeNewDecimal:
a.RetTp = types.NewFieldType(mysql.TypeNewDecimal)
if a.Args[0].GetType().GetDecimal() < 0 {
a.RetTp.SetDecimal(mysql.MaxDecimalScale)
} else {
a.RetTp.SetDecimal(mathutil.Min(a.Args[0].GetType().GetDecimal()+types.DivFracIncr, mysql.MaxDecimalScale))
}
a.RetTp.SetFlen(mathutil.Min(mysql.MaxDecimalWidth, a.Args[0].GetType().GetFlen()+types.DivFracIncr))
if a.Args[0].GetType().GetFlen() < 0 {
a.RetTp.SetFlen(mysql.MaxDecimalWidth)
}
a.RetTp.UpdateFlenAndDecimalUnderLimit(a.Args[0].GetType(), types.DivFracIncr, types.DivFracIncr)
case mysql.TypeDouble, mysql.TypeFloat:
a.RetTp = types.NewFieldType(mysql.TypeDouble)
a.RetTp.SetFlen(mysql.MaxRealWidth)
Expand Down
12 changes: 6 additions & 6 deletions expression/builtin.go
Original file line number Diff line number Diff line change
Expand Up @@ -887,11 +887,11 @@ func GetBuiltinList() []string {
}

func (b *baseBuiltinFunc) setDecimalAndFlenForDatetime(fsp int) {
b.tp.SetDecimal(fsp)
b.tp.SetFlen(mysql.MaxDatetimeWidthNoFsp + fsp)
b.tp.SetDecimalUnderLimit(fsp)
b.tp.SetFlenUnderLimit(mysql.MaxDatetimeWidthNoFsp + fsp)
if fsp > 0 {
// Add the length for `.`.
b.tp.SetFlen(b.tp.GetFlen() + 1)
b.tp.SetFlenUnderLimit(b.tp.GetFlen() + 1)
}
}

Expand All @@ -902,10 +902,10 @@ func (b *baseBuiltinFunc) setDecimalAndFlenForDate() {
}

func (b *baseBuiltinFunc) setDecimalAndFlenForTime(fsp int) {
b.tp.SetDecimal(fsp)
b.tp.SetFlen(mysql.MaxDurationWidthNoFsp + fsp)
b.tp.SetDecimalUnderLimit(fsp)
b.tp.SetFlenUnderLimit(mysql.MaxDurationWidthNoFsp + fsp)
if fsp > 0 {
// Add the length for `.`.
b.tp.SetFlen(b.tp.GetFlen() + 1)
b.tp.SetFlenUnderLimit(b.tp.GetFlen() + 1)
}
}
29 changes: 10 additions & 19 deletions expression/builtin_arithmetic.go
Original file line number Diff line number Diff line change
Expand Up @@ -90,9 +90,9 @@ func numericContextResultType(ft *types.FieldType) types.EvalType {
func setFlenDecimal4RealOrDecimal(ctx sessionctx.Context, retTp *types.FieldType, arg0, arg1 Expression, isReal bool, isMultiply bool) {
a, b := arg0.GetType(), arg1.GetType()
if a.GetDecimal() != types.UnspecifiedLength && b.GetDecimal() != types.UnspecifiedLength {
retTp.SetDecimal(a.GetDecimal() + b.GetDecimal())
retTp.SetDecimalUnderLimit(a.GetDecimal() + b.GetDecimal())
if !isMultiply {
retTp.SetDecimal(mathutil.Max(a.GetDecimal(), b.GetDecimal()))
retTp.SetDecimalUnderLimit(mathutil.Max(a.GetDecimal(), b.GetDecimal()))
}
if !isReal && retTp.GetDecimal() > mysql.MaxDecimalScale {
retTp.SetDecimal(mysql.MaxDecimalScale)
Expand All @@ -105,12 +105,12 @@ func setFlenDecimal4RealOrDecimal(ctx sessionctx.Context, retTp *types.FieldType
if isMultiply {
digitsInt = a.GetFlen() - a.GetDecimal() + b.GetFlen() - b.GetDecimal()
}
retTp.SetFlen(digitsInt + retTp.GetDecimal() + 1)
retTp.SetFlenUnderLimit(digitsInt + retTp.GetDecimal() + 1)
if isReal {
retTp.SetFlen(mathutil.Min(retTp.GetFlen(), mysql.MaxRealWidth))
retTp.SetFlenUnderLimit(retTp.GetFlen())
return
}
retTp.SetFlen(mathutil.Min(retTp.GetFlen(), mysql.MaxDecimalWidth))
retTp.SetFlenUnderLimit(retTp.GetFlen())
return
}
if isReal {
Expand All @@ -130,20 +130,14 @@ func (c *arithmeticDivideFunctionClass) setType4DivDecimal(retTp, a, b *types.Fi
if decb == types.UnspecifiedFsp {
decb = 0
}
retTp.SetDecimal(deca + precIncrement)
if retTp.GetDecimal() > mysql.MaxDecimalScale {
retTp.SetDecimal(mysql.MaxDecimalScale)
}
retTp.SetDecimalUnderLimit(deca + precIncrement)
if a.GetFlen() == types.UnspecifiedLength {
retTp.SetFlen(mysql.MaxDecimalWidth)
return
}
aPrec := types.DecimalLength2Precision(a.GetFlen(), a.GetDecimal(), mysql.HasUnsignedFlag(a.GetFlag()))
retTp.SetFlen(aPrec + decb + precIncrement)
retTp.SetFlen(types.Precision2LengthNoTruncation(retTp.GetFlen(), retTp.GetDecimal(), mysql.HasUnsignedFlag(retTp.GetFlag())))
if retTp.GetFlen() > mysql.MaxDecimalWidth {
retTp.SetFlen(mysql.MaxDecimalWidth)
}
retTp.SetFlenUnderLimit(aPrec + decb + precIncrement)
retTp.SetFlenUnderLimit(types.Precision2LengthNoTruncation(retTp.GetFlen(), retTp.GetDecimal(), mysql.HasUnsignedFlag(retTp.GetFlag())))
}

func (c *arithmeticDivideFunctionClass) setType4DivReal(retTp *types.FieldType) {
Expand Down Expand Up @@ -883,18 +877,15 @@ func (c *arithmeticModFunctionClass) setType4ModRealOrDecimal(retTp, a, b *types
if a.GetDecimal() == types.UnspecifiedLength || b.GetDecimal() == types.UnspecifiedLength {
retTp.SetDecimal(types.UnspecifiedLength)
} else {
retTp.SetDecimal(mathutil.Max(a.GetDecimal(), b.GetDecimal()))
if isDecimal && retTp.GetDecimal() > mysql.MaxDecimalScale {
retTp.SetDecimal(mysql.MaxDecimalScale)
}
retTp.SetDecimalUnderLimit(mathutil.Max(a.GetDecimal(), b.GetDecimal()))
}

if a.GetFlen() == types.UnspecifiedLength || b.GetFlen() == types.UnspecifiedLength {
retTp.SetFlen(types.UnspecifiedLength)
} else {
retTp.SetFlen(mathutil.Max(a.GetFlen(), b.GetFlen()))
if isDecimal {
retTp.SetFlen(mathutil.Min(retTp.GetFlen(), mysql.MaxDecimalWidth))
retTp.SetFlenUnderLimit(retTp.GetFlen())
return
}
retTp.SetFlen(mathutil.Min(retTp.GetFlen(), mysql.MaxRealWidth))
Expand Down
14 changes: 4 additions & 10 deletions expression/builtin_cast.go
Original file line number Diff line number Diff line change
Expand Up @@ -1934,8 +1934,8 @@ func WrapWithCastAsDecimal(ctx sessionctx.Context, expr Expression) Expression {
return expr
}
tp := types.NewFieldType(mysql.TypeNewDecimal)
tp.SetFlen(expr.GetType().GetFlen())
tp.SetDecimal(expr.GetType().GetDecimal())
tp.SetFlenUnderLimit(expr.GetType().GetFlen())
tp.SetDecimalUnderLimit(expr.GetType().GetDecimal())
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why not use UpdateFlenAndDecimalUnderLimit here?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

because it would change the logic if they are -1.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

besides, UpdateFlenAndDecimalUnderLimit aims to maintain valid when update the flen or decimal, where the flen and decimal should be not be -1.


if expr.GetType().EvalType() == types.ETInt {
tp.SetFlen(mysql.MaxIntWidth)
Expand All @@ -1952,14 +1952,8 @@ func WrapWithCastAsDecimal(ctx sessionctx.Context, expr Expression) Expression {
if !isnull && err == nil {
precision, frac := val.PrecisionAndFrac()
castTp := castExpr.GetType()
castTp.SetDecimal(frac)
castTp.SetFlen(precision)
if castTp.GetFlen() > mysql.MaxDecimalWidth {
castTp.SetFlen(mysql.MaxDecimalWidth)
}
if castTp.GetDecimal() > mysql.MaxDecimalScale {
castTp.SetDecimal(mysql.MaxDecimalScale)
}
castTp.SetDecimalUnderLimit(frac)
castTp.SetFlenUnderLimit(precision)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ditto

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we suppose the setting values are valid after PrecisionAndFrac. it would change the logic if they are -1. besides, UpdateFlenAndDecimalUnderLimit aims to maintain valid when update the flen or decimal, where the flen and decimal should be not be -1.

}
}
return castExpr
Expand Down
18 changes: 10 additions & 8 deletions expression/builtin_compare.go
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,8 @@ func (c *coalesceFunctionClass) getFunction(ctx sessionctx.Context, args []Expre
// Set retType to BINARY(0) if all arguments are of type NULL.
if resultFieldType.GetType() == mysql.TypeNull {
types.SetBinChsClnFlag(bf.tp)
resultFieldType.SetFlen(0)
resultFieldType.SetDecimal(0)
} else {
maxIntLen := 0
maxFlen := 0
Expand All @@ -160,7 +162,7 @@ func (c *coalesceFunctionClass) getFunction(ctx sessionctx.Context, args []Expre
// and max integer-part length in `maxIntLen`.
for _, argTp := range fieldTps {
if argTp.GetDecimal() > resultFieldType.GetDecimal() {
resultFieldType.SetDecimal(argTp.GetDecimal())
resultFieldType.SetDecimalUnderLimit(argTp.GetDecimal())
}
argIntLen := argTp.GetFlen()
if argTp.GetDecimal() > 0 {
Expand All @@ -181,12 +183,12 @@ func (c *coalesceFunctionClass) getFunction(ctx sessionctx.Context, args []Expre
// For integer, field length = maxIntLen + (1/0 for sign bit)
// For decimal, field length = maxIntLen + maxDecimal + (1/0 for sign bit)
if resultEvalType == types.ETInt || resultEvalType == types.ETDecimal {
resultFieldType.SetFlen(maxIntLen + resultFieldType.GetDecimal())
resultFieldType.SetFlenUnderLimit(maxIntLen + resultFieldType.GetDecimal())
if resultFieldType.GetDecimal() > 0 {
resultFieldType.SetFlen(resultFieldType.GetFlen() + 1)
resultFieldType.SetFlenUnderLimit(resultFieldType.GetFlen() + 1)
}
if !mysql.HasUnsignedFlag(resultFieldType.GetFlag()) {
resultFieldType.SetFlen(resultFieldType.GetFlen() + 1)
resultFieldType.SetFlenUnderLimit(resultFieldType.GetFlen() + 1)
}
bf.tp = resultFieldType
} else {
Expand Down Expand Up @@ -551,8 +553,8 @@ func (c *greatestFunctionClass) getFunction(ctx sessionctx.Context, args []Expre
}

flen, decimal := fixFlenAndDecimalForGreatestAndLeast(args)
sig.getRetTp().SetFlen(flen)
sig.getRetTp().SetDecimal(decimal)
sig.getRetTp().SetFlenUnderLimit(flen)
sig.getRetTp().SetDecimalUnderLimit(decimal)

return sig, nil
}
Expand Down Expand Up @@ -863,8 +865,8 @@ func (c *leastFunctionClass) getFunction(ctx sessionctx.Context, args []Expressi
}
}
flen, decimal := fixFlenAndDecimalForGreatestAndLeast(args)
sig.getRetTp().SetFlen(flen)
sig.getRetTp().SetDecimal(decimal)
sig.getRetTp().SetFlenUnderLimit(flen)
sig.getRetTp().SetDecimalUnderLimit(decimal)
return sig, nil
}

Expand Down
8 changes: 4 additions & 4 deletions expression/builtin_control.go
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ func InferType4ControlFuncs(ctx sessionctx.Context, funcName string, lexp, rexp
if lhs.GetDecimal() == types.UnspecifiedLength || rhs.GetDecimal() == types.UnspecifiedLength {
resultFieldType.SetDecimal(types.UnspecifiedLength)
} else {
resultFieldType.SetDecimal(mathutil.Max(lhs.GetDecimal(), rhs.GetDecimal()))
resultFieldType.SetDecimalUnderLimit(mathutil.Max(lhs.GetDecimal(), rhs.GetDecimal()))
}
}

Expand Down Expand Up @@ -146,7 +146,7 @@ func InferType4ControlFuncs(ctx sessionctx.Context, funcName string, lexp, rexp
rhsFlen -= rhs.GetDecimal()
}
flen := maxlen(lhsFlen, rhsFlen) + resultFieldType.GetDecimal() + 1 // account for -1 len fields
resultFieldType.SetFlen(mathutil.Min(flen, mysql.MaxDecimalWidth)) // make sure it doesn't overflow
resultFieldType.SetFlenUnderLimit(flen)

} else {
resultFieldType.SetFlen(maxlen(lhs.GetFlen(), rhs.GetFlen()))
Expand Down Expand Up @@ -225,7 +225,7 @@ func (c *caseWhenFunctionClass) getFunction(ctx sessionctx.Context, args []Expre
// Set retType to BINARY(0) if all arguments are of type NULL.
if fieldTp.GetType() == mysql.TypeNull {
fieldTp.SetFlen(0)
fieldTp.SetDecimal(types.UnspecifiedLength)
fieldTp.SetDecimal(0)
types.SetBinChsClnFlag(fieldTp)
}
argTps := make([]types.EvalType, 0, l)
Expand Down Expand Up @@ -748,7 +748,7 @@ func (c *ifNullFunctionClass) getFunction(ctx sessionctx.Context, args []Express
if lhs.GetType() == mysql.TypeNull && rhs.GetType() == mysql.TypeNull {
retTp.SetType(mysql.TypeNull)
retTp.SetFlen(0)
retTp.SetDecimal(-1)
retTp.SetDecimal(0)
types.SetBinChsClnFlag(retTp)
}
evalTps := retTp.EvalType()
Expand Down
Loading