Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

expression: fix the return type of coalesce when arg type is DATE | tidb-test=pr/2386 #55969

Merged
merged 2 commits into from
Sep 10, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
72 changes: 10 additions & 62 deletions expression/builtin_compare.go
Original file line number Diff line number Diff line change
Expand Up @@ -118,18 +118,19 @@ func (c *coalesceFunctionClass) getFunction(ctx sessionctx.Context, args []Expre
return nil, err
}

fieldTps := make([]*types.FieldType, 0, len(args))
flag := uint(0)
for _, arg := range args {
fieldTps = append(fieldTps, arg.GetType())
flag |= arg.GetType().GetFlag() & mysql.NotNullFlag
}

// Use the aggregated field type as retType.
resultFieldType := types.AggFieldType(fieldTps)
var tempType uint
resultEvalType := types.AggregateEvalType(fieldTps, &tempType)
resultFieldType.SetFlag(tempType)
retEvalTp := resultFieldType.EvalType()
resultFieldType, err := InferType4ControlFuncs(ctx, c.funcName, args...)
if err != nil {
return nil, err
}

resultFieldType.AddFlag(flag)

retEvalTp := resultFieldType.EvalType()
fieldEvalTps := make([]types.EvalType, 0, len(args))
for range args {
fieldEvalTps = append(fieldEvalTps, retEvalTp)
Expand All @@ -140,60 +141,7 @@ func (c *coalesceFunctionClass) getFunction(ctx sessionctx.Context, args []Expre
return nil, err
}

bf.tp.AddFlag(resultFieldType.GetFlag())
resultFieldType.SetFlen(0)
resultFieldType.SetDecimal(types.UnspecifiedLength)

// Set retType to BINARY(0) if all arguments are of type NULL.
if resultFieldType.GetType() == mysql.TypeNull {
types.SetBinChsClnFlag(bf.tp)
resultFieldType.SetFlen(0)
resultFieldType.SetDecimal(0)
} else {
maxIntLen := 0
maxFlen := 0

// Find the max length of field in `maxFlen`,
// and max integer-part length in `maxIntLen`.
for _, argTp := range fieldTps {
if argTp.GetDecimal() > resultFieldType.GetDecimal() {
resultFieldType.SetDecimalUnderLimit(argTp.GetDecimal())
}
argIntLen := argTp.GetFlen()
if argTp.GetDecimal() > 0 {
argIntLen -= argTp.GetDecimal() + 1
}

// Reduce the sign bit if it is a signed integer/decimal
if !mysql.HasUnsignedFlag(argTp.GetFlag()) {
argIntLen--
}
if argIntLen > maxIntLen {
maxIntLen = argIntLen
}
if argTp.GetFlen() > maxFlen || argTp.GetFlen() == types.UnspecifiedLength {
maxFlen = argTp.GetFlen()
}
}
// For integer, field length = maxIntLen + (1/0 for sign bit)
// For decimal, field length = maxIntLen + maxDecimal + (1/0 for sign bit)
if resultEvalType == types.ETInt || resultEvalType == types.ETDecimal {
resultFieldType.SetFlenUnderLimit(maxIntLen + resultFieldType.GetDecimal())
if resultFieldType.GetDecimal() > 0 {
resultFieldType.SetFlenUnderLimit(resultFieldType.GetFlen() + 1)
}
if !mysql.HasUnsignedFlag(resultFieldType.GetFlag()) {
resultFieldType.SetFlenUnderLimit(resultFieldType.GetFlen() + 1)
}
bf.tp = resultFieldType
} else {
bf.tp.SetFlen(maxFlen)
}
// Set the field length to maxFlen for other types.
if bf.tp.GetFlen() > mysql.MaxDecimalWidth {
bf.tp.SetFlen(mysql.MaxDecimalWidth)
}
}
bf.tp = resultFieldType

switch retEvalTp {
case types.ETInt:
Expand Down
254 changes: 191 additions & 63 deletions expression/builtin_control.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
package expression

import (
"github.com/pingcap/tidb/parser/ast"
"github.com/pingcap/tidb/parser/mysql"
"github.com/pingcap/tidb/sessionctx"
"github.com/pingcap/tidb/types"
Expand Down Expand Up @@ -61,47 +62,92 @@ func maxlen(lhsFlen, rhsFlen int) int {
return mathutil.Max(lhsFlen, rhsFlen)
}

// InferType4ControlFuncs infer result type for builtin IF, IFNULL, NULLIF, LEAD and LAG.
func InferType4ControlFuncs(ctx sessionctx.Context, funcName string, lexp, rexp Expression) (*types.FieldType, error) {
lhs, rhs := lexp.GetType(), rexp.GetType()
resultFieldType := &types.FieldType{}
if lhs.GetType() == mysql.TypeNull {
*resultFieldType = *rhs
// If any of arg is NULL, result type need unset NotNullFlag.
tempFlag := resultFieldType.GetFlag()
types.SetTypeFlag(&tempFlag, mysql.NotNullFlag, false)
resultFieldType.SetFlag(tempFlag)
// If both arguments are NULL, make resulting type BINARY(0).
if rhs.GetType() == mysql.TypeNull {
resultFieldType.SetType(mysql.TypeString)
resultFieldType.SetFlen(0)
resultFieldType.SetDecimal(0)
types.SetBinChsClnFlag(resultFieldType)
func setFlenFromArgs(evalType types.EvalType, resultFieldType *types.FieldType, argTps ...*types.FieldType) {
if evalType == types.ETDecimal || evalType == types.ETInt {
maxArgFlen := 0
for i := range argTps {
flagLen := 0
if !mysql.HasUnsignedFlag(argTps[i].GetFlag()) {
flagLen = 1
}
flen := argTps[i].GetFlen() - flagLen
if argTps[i].GetDecimal() != types.UnspecifiedLength {
flen -= argTps[i].GetDecimal()
}
maxArgFlen = maxlen(maxArgFlen, flen)
}
} else if rhs.GetType() == mysql.TypeNull {
*resultFieldType = *lhs
tempFlag := resultFieldType.GetFlag()
types.SetTypeFlag(&tempFlag, mysql.NotNullFlag, false)
resultFieldType.SetFlag(tempFlag)
// For a decimal field, the `length` and `flen` are not the same.
// `length` only holds the binary data, while `flen` represents the number of digits required to display the field, including the negative sign.
// In the current implementation of TiDB, `flen` and `length` are treated as the same, so the `length` of a decimal may be inconsistent with that of MySQL.
resultFlen := maxArgFlen + resultFieldType.GetDecimal() + 1 // account for -1 len fields
resultFieldType.SetFlenUnderLimit(resultFlen)
} else if evalType == types.ETString {
maxLen := 0
for i := range argTps {
argFlen := argTps[i].GetFlen()
if argFlen == types.UnspecifiedLength {
resultFieldType.SetFlen(types.UnspecifiedLength)
return
}
maxLen = maxlen(argFlen, maxLen)
}
resultFieldType.SetFlen(maxLen)
} else {
resultFieldType = types.AggFieldType([]*types.FieldType{lhs, rhs})
var tempFlag uint
evalType := types.AggregateEvalType([]*types.FieldType{lhs, rhs}, &tempFlag)
resultFieldType.SetFlag(tempFlag)
if evalType == types.ETInt {
resultFieldType.SetDecimal(0)
} else {
if lhs.GetDecimal() == types.UnspecifiedLength || rhs.GetDecimal() == types.UnspecifiedLength {
maxLen := 0
for i := range argTps {
maxLen = maxlen(argTps[i].GetFlen(), maxLen)
}
resultFieldType.SetFlen(maxLen)
}
}

func setDecimalFromArgs(evalType types.EvalType, resultFieldType *types.FieldType, argTps ...*types.FieldType) {
if evalType == types.ETInt {
resultFieldType.SetDecimal(0)
} else {
maxDecimal := 0
for i := range argTps {
if argTps[i].GetDecimal() == types.UnspecifiedLength {
resultFieldType.SetDecimal(types.UnspecifiedLength)
} else {
resultFieldType.SetDecimalUnderLimit(mathutil.Max(lhs.GetDecimal(), rhs.GetDecimal()))
return
}
maxDecimal = mathutil.Max(argTps[i].GetDecimal(), maxDecimal)
}
resultFieldType.SetDecimalUnderLimit(maxDecimal)
}
}

// NonBinaryStr means the arg is a string but not binary string
func hasNonBinaryStr(args []*types.FieldType) bool {
for _, arg := range args {
if types.IsNonBinaryStr(arg) {
return true
}
}
return false
}

func hasBinaryStr(args []*types.FieldType) bool {
for _, arg := range args {
if types.IsBinaryStr(arg) {
return true
}
}
return false
}

func addCollateAndCharsetAndFlagFromArgs(ctx sessionctx.Context, funcName string, evalType types.EvalType, resultFieldType *types.FieldType, args ...Expression) error {
switch funcName {
case ast.If, ast.Ifnull, ast.WindowFuncLead, ast.WindowFuncLag:
if len(args) != 2 {
panic("unexpected length of args for if/ifnull/lead/lag")
}
lexp, rexp := args[0], args[1]
lhs, rhs := lexp.GetType(), rexp.GetType()
if types.IsNonBinaryStr(lhs) && !types.IsBinaryStr(rhs) {
ec, err := CheckAndDeriveCollationFromExprs(ctx, funcName, evalType, lexp, rexp)
if err != nil {
return nil, err
return err
}
resultFieldType.SetCollate(ec.Collation)
resultFieldType.SetCharset(ec.Charset)
Expand All @@ -112,7 +158,7 @@ func InferType4ControlFuncs(ctx sessionctx.Context, funcName string, lexp, rexp
} else if types.IsNonBinaryStr(rhs) && !types.IsBinaryStr(lhs) {
ec, err := CheckAndDeriveCollationFromExprs(ctx, funcName, evalType, lexp, rexp)
if err != nil {
return nil, err
return err
}
resultFieldType.SetCollate(ec.Collation)
resultFieldType.SetCharset(ec.Charset)
Expand All @@ -127,49 +173,131 @@ func InferType4ControlFuncs(ctx sessionctx.Context, funcName string, lexp, rexp
resultFieldType.SetCollate(mysql.DefaultCollationName)
resultFieldType.SetFlag(0)
}
if evalType == types.ETDecimal || evalType == types.ETInt {
lhsUnsignedFlag, rhsUnsignedFlag := mysql.HasUnsignedFlag(lhs.GetFlag()), mysql.HasUnsignedFlag(rhs.GetFlag())
lhsFlagLen, rhsFlagLen := 0, 0
if !lhsUnsignedFlag {
lhsFlagLen = 1
}
if !rhsUnsignedFlag {
rhsFlagLen = 1
case ast.Case:
if len(args) == 0 {
panic("unexpected length 0 of args for casewhen")
}
ec, err := CheckAndDeriveCollationFromExprs(ctx, funcName, evalType, args...)
if err != nil {
return err
}
resultFieldType.SetCollate(ec.Collation)
resultFieldType.SetCharset(ec.Charset)
for i := range args {
if mysql.HasBinaryFlag(args[i].GetType().GetFlag()) || !types.IsNonBinaryStr(args[i].GetType()) {
resultFieldType.AddFlag(mysql.BinaryFlag)
break
}
lhsFlen := lhs.GetFlen() - lhsFlagLen
rhsFlen := rhs.GetFlen() - rhsFlagLen
if lhs.GetDecimal() != types.UnspecifiedLength {
lhsFlen -= lhs.GetDecimal()
}
case ast.Coalesce: // TODO ast.Case and ast.Coalesce should be merged into the same branch
argTypes := make([]*types.FieldType, 0)
for _, arg := range args {
argTypes = append(argTypes, arg.GetType())
}

nonBinaryStrExist := hasNonBinaryStr(argTypes)
binaryStrExist := hasBinaryStr(argTypes)
if !binaryStrExist && nonBinaryStrExist {
ec, err := CheckAndDeriveCollationFromExprs(ctx, funcName, evalType, args...)
if err != nil {
return err
}
if lhs.GetDecimal() != types.UnspecifiedLength {
rhsFlen -= rhs.GetDecimal()
resultFieldType.SetCollate(ec.Collation)
resultFieldType.SetCharset(ec.Charset)
resultFieldType.SetFlag(0)

// hasNonStringType means that there is a type that is not string
hasNonStringType := false
for _, argType := range argTypes {
if !types.IsString(argType.GetType()) {
hasNonStringType = true
break
}
}
flen := maxlen(lhsFlen, rhsFlen) + resultFieldType.GetDecimal() + 1 // account for -1 len fields
resultFieldType.SetFlenUnderLimit(flen)
} else if evalType == types.ETString {
lhsLen, rhsLen := lhs.GetFlen(), rhs.GetFlen()
if lhsLen != types.UnspecifiedLength && rhsLen != types.UnspecifiedLength {
resultFieldType.SetFlen(mathutil.Max(lhsLen, rhsLen))

if hasNonStringType {
resultFieldType.AddFlag(mysql.BinaryFlag)
}
} else if binaryStrExist || !evalType.IsStringKind() {
types.SetBinChsClnFlag(resultFieldType)
} else {
resultFieldType.SetFlen(maxlen(lhs.GetFlen(), rhs.GetFlen()))
resultFieldType.SetCharset(mysql.DefaultCharset)
resultFieldType.SetCollate(mysql.DefaultCollationName)
resultFieldType.SetFlag(0)
}
default:
panic("unexpected function: " + funcName)
}
// Fix decimal for int and string.
resultEvalType := resultFieldType.EvalType()
if resultEvalType == types.ETInt {
return nil
}

// InferType4ControlFuncs infer result type for builtin IF, IFNULL, NULLIF, CASEWHEN, COALESCE, LEAD and LAG.
func InferType4ControlFuncs(ctx sessionctx.Context, funcName string, args ...Expression) (*types.FieldType, error) {
argsNum := len(args)
if argsNum == 0 {
panic("unexpected length 0 of args")
}
nullFields := make([]*types.FieldType, 0, argsNum)
notNullFields := make([]*types.FieldType, 0, argsNum)
for i := range args {
if args[i].GetType().GetType() == mysql.TypeNull {
nullFields = append(nullFields, args[i].GetType())
} else {
notNullFields = append(notNullFields, args[i].GetType())
}
}
resultFieldType := &types.FieldType{}
if len(nullFields) == argsNum { // all field is TypeNull
*resultFieldType = *nullFields[0]
// If any of arg is NULL, result type need unset NotNullFlag.
tempFlag := resultFieldType.GetFlag()
types.SetTypeFlag(&tempFlag, mysql.NotNullFlag, false)
resultFieldType.SetFlag(tempFlag)

resultFieldType.SetType(mysql.TypeNull)
resultFieldType.SetFlen(0)
resultFieldType.SetDecimal(0)
if resultFieldType.GetType() == mysql.TypeEnum || resultFieldType.GetType() == mysql.TypeSet {
resultFieldType.SetType(mysql.TypeLonglong)
types.SetBinChsClnFlag(resultFieldType)
} else {
if len(notNullFields) == 1 {
*resultFieldType = *notNullFields[0]
} else {
resultFieldType = types.AggFieldType(notNullFields)
var tempFlag uint
evalType := types.AggregateEvalType(notNullFields, &tempFlag)
resultFieldType.SetFlag(tempFlag)
setDecimalFromArgs(evalType, resultFieldType, notNullFields...)
err := addCollateAndCharsetAndFlagFromArgs(ctx, funcName, evalType, resultFieldType, args...)
if err != nil {
return nil, err
}
setFlenFromArgs(evalType, resultFieldType, notNullFields...)
}
} else if resultEvalType == types.ETString {
if lhs.GetType() != mysql.TypeNull || rhs.GetType() != mysql.TypeNull {

// If any of arg is NULL, result type need unset NotNullFlag.
if len(nullFields) > 0 {
tempFlag := resultFieldType.GetFlag()
types.SetTypeFlag(&tempFlag, mysql.NotNullFlag, false)
resultFieldType.SetFlag(tempFlag)
}

resultEvalType := resultFieldType.EvalType()
// fix decimal for int and string.
if resultEvalType == types.ETInt {
resultFieldType.SetDecimal(0)
} else if resultEvalType == types.ETString {
resultFieldType.SetDecimal(types.UnspecifiedLength)
}
// fix type for enum and set
if resultFieldType.GetType() == mysql.TypeEnum || resultFieldType.GetType() == mysql.TypeSet {
resultFieldType.SetType(mysql.TypeVarchar)
switch resultEvalType {
case types.ETInt:
resultFieldType.SetType(mysql.TypeLonglong)
case types.ETString:
resultFieldType.SetType(mysql.TypeVarchar)
}
}
} else if resultFieldType.GetType() == mysql.TypeDatetime {
// fix flen for datetime
types.TryToFixFlenOfDatetime(resultFieldType)
}
return resultFieldType, nil
Expand Down
Loading
Loading