Skip to content

Commit

Permalink
This is an automated cherry-pick of pingcap#48032
Browse files Browse the repository at this point in the history
Signed-off-by: ti-chi-bot <ti-community-prow-bot@tidb.io>
  • Loading branch information
xzhangxian1008 authored and ti-chi-bot committed Nov 8, 2023
1 parent 8b2e9f9 commit c285279
Show file tree
Hide file tree
Showing 5 changed files with 155 additions and 12 deletions.
22 changes: 14 additions & 8 deletions expression/builtin_compare.go
Original file line number Diff line number Diff line change
Expand Up @@ -118,18 +118,19 @@ func (c *coalesceFunctionClass) getFunction(ctx sessionctx.Context, args []Expre
return nil, err
}

fieldTps := make([]*types.FieldType, 0, len(args))
flag := uint(0)
for _, arg := range args {
fieldTps = append(fieldTps, arg.GetType())
flag |= arg.GetType().GetFlag() & mysql.NotNullFlag
}

// Use the aggregated field type as retType.
resultFieldType := types.AggFieldType(fieldTps)
var tempType uint
resultEvalType := types.AggregateEvalType(fieldTps, &tempType)
resultFieldType.SetFlag(tempType)
retEvalTp := resultFieldType.EvalType()
resultFieldType, err := InferType4ControlFuncs(ctx, c.funcName, args...)
if err != nil {
return nil, err
}

resultFieldType.AddFlag(flag)

retEvalTp := resultFieldType.EvalType()
fieldEvalTps := make([]types.EvalType, 0, len(args))
for range args {
fieldEvalTps = append(fieldEvalTps, retEvalTp)
Expand All @@ -145,6 +146,7 @@ func (c *coalesceFunctionClass) getFunction(ctx sessionctx.Context, args []Expre
return nil, err
}

<<<<<<< HEAD:expression/builtin_compare.go
bf.tp.AddFlag(resultFieldType.GetFlag())
resultFieldType.SetFlen(0)
resultFieldType.SetDecimal(types.UnspecifiedLength)
Expand Down Expand Up @@ -197,6 +199,9 @@ func (c *coalesceFunctionClass) getFunction(ctx sessionctx.Context, args []Expre
bf.tp.SetFlen(mysql.MaxDecimalWidth)
}
}
=======
bf.tp = resultFieldType
>>>>>>> 7cb7af71792 (expression: fix the return type of `coalesce` when arg type is `DATE` (#48032)):pkg/expression/builtin_compare.go

switch retEvalTp {
case types.ETInt:
Expand Down Expand Up @@ -1250,6 +1255,7 @@ func (b *builtinIntervalRealSig) evalInt(row chunk.Row) (int64, bool, error) {
if isNull {
return -1, false, nil
}

var idx int
if b.hasNullable {
idx, err = b.linearSearch(arg0, b.args[1:], row)
Expand Down
26 changes: 26 additions & 0 deletions expression/builtin_compare_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -389,3 +389,29 @@ func TestGreatestLeastFunc(t *testing.T) {
_, err = funcs[ast.Least].getFunction(ctx, []Expression{NewZero(), NewOne()})
require.NoError(t, err)
}
<<<<<<< HEAD:expression/builtin_compare_test.go
=======

func TestRefineArgsWithCastEnum(t *testing.T) {
ctx := createContext(t)
zeroUintConst := primitiveValsToConstants(ctx, []interface{}{uint64(0)})[0]
enumType := types.NewFieldTypeBuilder().SetType(mysql.TypeEnum).SetElems([]string{"1", "2", "3"}).AddFlag(mysql.EnumSetAsIntFlag).Build()
enumCol := &Column{RetType: &enumType}

f := funcs[ast.EQ].(*compareFunctionClass)
require.NotNil(t, f)

args := f.refineArgsByUnsignedFlag(ctx, []Expression{zeroUintConst, enumCol})
require.Equal(t, zeroUintConst, args[0])
require.Equal(t, enumCol, args[1])
}

func TestIssue46475(t *testing.T) {
ctx := createContext(t)
args := []interface{}{nil, dt, nil}

f, err := newFunctionForTest(ctx, ast.Coalesce, primitiveValsToConstants(ctx, args)...)
require.NoError(t, err)
require.Equal(t, f.GetType().GetType(), mysql.TypeDate)
}
>>>>>>> 7cb7af71792 (expression: fix the return type of `coalesce` when arg type is `DATE` (#48032)):pkg/expression/builtin_compare_test.go
100 changes: 100 additions & 0 deletions expression/builtin_control.go
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,36 @@ func InferType4ControlFuncs(ctx sessionctx.Context, funcName string, lexp, rexp
}
}

<<<<<<< HEAD:expression/builtin_control.go
=======
// NonBinaryStr means the arg is a string but not binary string
func hasNonBinaryStr(args []*types.FieldType) bool {
for _, arg := range args {
if types.IsNonBinaryStr(arg) {
return true
}
}
return false
}

func hasBinaryStr(args []*types.FieldType) bool {
for _, arg := range args {
if types.IsBinaryStr(arg) {
return true
}
}
return false
}

func addCollateAndCharsetAndFlagFromArgs(ctx sessionctx.Context, funcName string, evalType types.EvalType, resultFieldType *types.FieldType, args ...Expression) error {
switch funcName {
case ast.If, ast.Ifnull, ast.WindowFuncLead, ast.WindowFuncLag:
if len(args) != 2 {
panic("unexpected length of args for if/ifnull/lead/lag")
}
lexp, rexp := args[0], args[1]
lhs, rhs := lexp.GetType(), rexp.GetType()
>>>>>>> 7cb7af71792 (expression: fix the return type of `coalesce` when arg type is `DATE` (#48032)):pkg/expression/builtin_control.go
if types.IsNonBinaryStr(lhs) && !types.IsBinaryStr(rhs) {
ec, err := CheckAndDeriveCollationFromExprs(ctx, funcName, evalType, lexp, rexp)
if err != nil {
Expand Down Expand Up @@ -152,13 +182,83 @@ func InferType4ControlFuncs(ctx sessionctx.Context, funcName string, lexp, rexp
if lhsLen != types.UnspecifiedLength && rhsLen != types.UnspecifiedLength {
resultFieldType.SetFlen(mathutil.Max(lhsLen, rhsLen))
}
<<<<<<< HEAD:expression/builtin_control.go
=======
}
case ast.Coalesce: // TODO ast.Case and ast.Coalesce should be merged into the same branch
argTypes := make([]*types.FieldType, 0)
for _, arg := range args {
argTypes = append(argTypes, arg.GetType())
}

nonBinaryStrExist := hasNonBinaryStr(argTypes)
binaryStrExist := hasBinaryStr(argTypes)
if !binaryStrExist && nonBinaryStrExist {
ec, err := CheckAndDeriveCollationFromExprs(ctx, funcName, evalType, args...)
if err != nil {
return err
}
resultFieldType.SetCollate(ec.Collation)
resultFieldType.SetCharset(ec.Charset)
resultFieldType.SetFlag(0)

// hasNonStringType means that there is a type that is not string
hasNonStringType := false
for _, argType := range argTypes {
if !types.IsString(argType.GetType()) {
hasNonStringType = true
break
}
}

if hasNonStringType {
resultFieldType.AddFlag(mysql.BinaryFlag)
}
} else if binaryStrExist || !evalType.IsStringKind() {
types.SetBinChsClnFlag(resultFieldType)
} else {
resultFieldType.SetCharset(mysql.DefaultCharset)
resultFieldType.SetCollate(mysql.DefaultCollationName)
resultFieldType.SetFlag(0)
}
default:
panic("unexpected function: " + funcName)
}
return nil
}

// InferType4ControlFuncs infer result type for builtin IF, IFNULL, NULLIF, CASEWHEN, COALESCE, LEAD and LAG.
func InferType4ControlFuncs(ctx sessionctx.Context, funcName string, args ...Expression) (*types.FieldType, error) {
argsNum := len(args)
if argsNum == 0 {
panic("unexpected length 0 of args")
}
nullFields := make([]*types.FieldType, 0, argsNum)
notNullFields := make([]*types.FieldType, 0, argsNum)
for i := range args {
if args[i].GetType().GetType() == mysql.TypeNull {
nullFields = append(nullFields, args[i].GetType())
>>>>>>> 7cb7af71792 (expression: fix the return type of `coalesce` when arg type is `DATE` (#48032)):pkg/expression/builtin_control.go
} else {
resultFieldType.SetFlen(maxlen(lhs.GetFlen(), rhs.GetFlen()))
}
}
<<<<<<< HEAD:expression/builtin_control.go
// Fix decimal for int and string.
resultEvalType := resultFieldType.EvalType()
if resultEvalType == types.ETInt {
=======
resultFieldType := &types.FieldType{}
if len(nullFields) == argsNum { // all field is TypeNull
*resultFieldType = *nullFields[0]
// If any of arg is NULL, result type need unset NotNullFlag.
tempFlag := resultFieldType.GetFlag()
types.SetTypeFlag(&tempFlag, mysql.NotNullFlag, false)
resultFieldType.SetFlag(tempFlag)

resultFieldType.SetType(mysql.TypeNull)
resultFieldType.SetFlen(0)
>>>>>>> 7cb7af71792 (expression: fix the return type of `coalesce` when arg type is `DATE` (#48032)):pkg/expression/builtin_control.go
resultFieldType.SetDecimal(0)
if resultFieldType.GetType() == mysql.TypeEnum || resultFieldType.GetType() == mysql.TypeSet {
resultFieldType.SetType(mysql.TypeLonglong)
Expand Down
5 changes: 5 additions & 0 deletions expression/expr_to_pb_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -493,8 +493,13 @@ func TestOtherFunc2Pb(t *testing.T) {
pbExprs, err := ExpressionsToPBList(sc, otherFuncs, client)
require.NoError(t, err)
jsons := map[string]string{
<<<<<<< HEAD:expression/expr_to_pb_test.go
ast.Coalesce: "{\"tp\":10000,\"children\":[{\"tp\":201,\"val\":\"gAAAAAAAAAE=\",\"sig\":0,\"field_type\":{\"tp\":3,\"flag\":0,\"flen\":-1,\"decimal\":-1,\"collate\":-63,\"charset\":\"binary\"},\"has_distinct\":false}],\"sig\":4201,\"field_type\":{\"tp\":3,\"flag\":128,\"flen\":0,\"decimal\":-1,\"collate\":-63,\"charset\":\"binary\"},\"has_distinct\":false}",
ast.IsNull: "{\"tp\":10000,\"children\":[{\"tp\":201,\"val\":\"gAAAAAAAAAE=\",\"sig\":0,\"field_type\":{\"tp\":3,\"flag\":0,\"flen\":-1,\"decimal\":-1,\"collate\":-63,\"charset\":\"binary\"},\"has_distinct\":false}],\"sig\":3116,\"field_type\":{\"tp\":8,\"flag\":524416,\"flen\":1,\"decimal\":0,\"collate\":-63,\"charset\":\"binary\"},\"has_distinct\":false}",
=======
ast.Coalesce: "{\"tp\":10000,\"children\":[{\"tp\":201,\"val\":\"gAAAAAAAAAE=\",\"sig\":0,\"field_type\":{\"tp\":3,\"flag\":0,\"flen\":11,\"decimal\":0,\"collate\":-63,\"charset\":\"binary\",\"array\":false},\"has_distinct\":false}],\"sig\":4201,\"field_type\":{\"tp\":3,\"flag\":0,\"flen\":11,\"decimal\":0,\"collate\":-63,\"charset\":\"binary\",\"array\":false},\"has_distinct\":false}",
ast.IsNull: "{\"tp\":10000,\"children\":[{\"tp\":201,\"val\":\"gAAAAAAAAAE=\",\"sig\":0,\"field_type\":{\"tp\":3,\"flag\":0,\"flen\":11,\"decimal\":0,\"collate\":-63,\"charset\":\"binary\",\"array\":false},\"has_distinct\":false}],\"sig\":3116,\"field_type\":{\"tp\":8,\"flag\":524417,\"flen\":1,\"decimal\":0,\"collate\":-63,\"charset\":\"binary\",\"array\":false},\"has_distinct\":false}",
>>>>>>> 7cb7af71792 (expression: fix the return type of `coalesce` when arg type is `DATE` (#48032)):pkg/expression/expr_to_pb_test.go
}
for i, pbExpr := range pbExprs {
js, err := json.Marshal(pbExpr)
Expand Down
14 changes: 10 additions & 4 deletions expression/typeinfer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1018,10 +1018,16 @@ func (s *InferTypeSuite) createTestCase4EncryptionFuncs() []typeInferTestCase {

func (s *InferTypeSuite) createTestCase4CompareFuncs() []typeInferTestCase {
return []typeInferTestCase{
{"coalesce(c_int_d, 1)", mysql.TypeLonglong, charset.CharsetBin, mysql.BinaryFlag, 11, 0},
{"coalesce(NULL, c_int_d)", mysql.TypeLong, charset.CharsetBin, mysql.BinaryFlag, 11, 0},
{"coalesce(c_int_d, c_decimal)", mysql.TypeNewDecimal, charset.CharsetBin, mysql.BinaryFlag, 15, 3},
{"coalesce(c_int_d, c_datetime)", mysql.TypeVarString, charset.CharsetUTF8MB4, 0, 22, types.UnspecifiedLength},
{"coalesce(c_int_d, c_int_d)", mysql.TypeLong, charset.CharsetBin, mysql.BinaryFlag, 11, 0},
{"coalesce(c_int_d, c_decimal)", mysql.TypeNewDecimal, charset.CharsetBin, mysql.BinaryFlag, 14, 3},
{"coalesce(c_int_d, c_char)", mysql.TypeString, charset.CharsetUTF8MB4, mysql.BinaryFlag, 20, types.UnspecifiedLength},
{"coalesce(c_int_d, c_binary)", mysql.TypeString, charset.CharsetBin, mysql.BinaryFlag, 20, types.UnspecifiedLength},
{"coalesce(c_char, c_binary)", mysql.TypeString, charset.CharsetBin, mysql.BinaryFlag, 20, types.UnspecifiedLength},
{"coalesce(null, null)", mysql.TypeNull, charset.CharsetBin, mysql.BinaryFlag, 0, 0},
{"coalesce(c_double_d, c_timestamp_d)", mysql.TypeVarchar, charset.CharsetUTF8MB4, 0, 22, types.UnspecifiedLength},
{"coalesce(c_json, c_decimal)", mysql.TypeLongBlob, charset.CharsetUTF8MB4, 0, math.MaxUint32, types.UnspecifiedLength},
{"coalesce(c_time, c_date)", mysql.TypeDatetime, charset.CharsetUTF8MB4, 0, mysql.MaxDatetimeWidthNoFsp + 3 + 1, 3},
{"coalesce(c_time_d, c_date)", mysql.TypeDatetime, charset.CharsetUTF8MB4, 0, mysql.MaxDatetimeWidthNoFsp, 0},

{"isnull(c_int_d )", mysql.TypeLonglong, charset.CharsetBin, mysql.BinaryFlag | mysql.IsBooleanFlag, 1, 0},
{"isnull(c_bigint_d )", mysql.TypeLonglong, charset.CharsetBin, mysql.BinaryFlag | mysql.IsBooleanFlag, 1, 0},
Expand Down

0 comments on commit c285279

Please sign in to comment.