Skip to content

Commit

Permalink
expression: fix type infer for tidb's builtin compare(least and great…
Browse files Browse the repository at this point in the history
…est) (#21150)

Signed-off-by: iosmanthus <myosmanthustree@gmail.com>
  • Loading branch information
iosmanthus authored Dec 22, 2020
1 parent cf806f6 commit dd0dc46
Show file tree
Hide file tree
Showing 8 changed files with 172 additions and 133 deletions.
178 changes: 107 additions & 71 deletions expression/builtin_compare.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ package expression

import (
"math"
"strings"

"github.com/pingcap/parser/ast"
"github.com/pingcap/parser/mysql"
Expand Down Expand Up @@ -367,53 +368,67 @@ func (b *builtinCoalesceJSONSig) evalJSON(row chunk.Row) (res json.BinaryJSON, i
return res, isNull, err
}

// temporalWithDateAsNumEvalType makes DATE, DATETIME, TIMESTAMP pretend to be numbers rather than strings.
func temporalWithDateAsNumEvalType(argTp *types.FieldType) (argEvalType types.EvalType, isStr bool, isTemporalWithDate bool) {
argEvalType = argTp.EvalType()
isStr, isTemporalWithDate = argEvalType.IsStringKind(), types.IsTemporalWithDate(argTp.Tp)
if !isTemporalWithDate {
return
func aggregateType(args []Expression) *types.FieldType {
fieldTypes := make([]*types.FieldType, len(args))
for i := range fieldTypes {
fieldTypes[i] = args[i].GetType()
}
if argTp.Decimal > 0 {
argEvalType = types.ETDecimal
} else {
argEvalType = types.ETInt
}
return
return types.AggFieldType(fieldTypes)
}

// GetCmpTp4MinMax gets compare type for GREATEST and LEAST and BETWEEN
func GetCmpTp4MinMax(args []Expression) (argTp types.EvalType) {
datetimeFound, isAllStr := false, true
cmpEvalType, isStr, isTemporalWithDate := temporalWithDateAsNumEvalType(args[0].GetType())
if !isStr {
isAllStr = false
// ResolveType4Between resolves eval type for between expression.
func ResolveType4Between(args [3]Expression) types.EvalType {
cmpTp := args[0].GetType().EvalType()
for i := 1; i < 3; i++ {
cmpTp = getBaseCmpType(cmpTp, args[i].GetType().EvalType(), nil, nil)
}
if isTemporalWithDate {
datetimeFound = true
}
lft := args[0].GetType()
for i := range args {
rft := args[i].GetType()
var tp types.EvalType
tp, isStr, isTemporalWithDate = temporalWithDateAsNumEvalType(rft)
if isTemporalWithDate {
datetimeFound = true

hasTemporal := false
if cmpTp == types.ETString {
for _, arg := range args {
if types.IsTypeTemporal(arg.GetType().Tp) {
hasTemporal = true
break
}
}
if !isStr {
isAllStr = false
if hasTemporal {
cmpTp = types.ETDatetime
}
cmpEvalType = getBaseCmpType(cmpEvalType, tp, lft, rft)
lft = rft
}
argTp = cmpEvalType
if cmpEvalType.IsStringKind() {
argTp = types.ETString

return cmpTp
}

// resolveType4Extremum gets compare type for GREATEST and LEAST and BETWEEN (mainly for datetime).
func resolveType4Extremum(args []Expression) types.EvalType {
aggType := aggregateType(args)

var temporalItem *types.FieldType
if aggType.EvalType().IsStringKind() {
for i := range args {
item := args[i].GetType()
if types.IsTemporalWithDate(item.Tp) {
temporalItem = item
}
}

if !types.IsTemporalWithDate(aggType.Tp) && temporalItem != nil {
aggType.Tp = temporalItem.Tp
}
// TODO: String charset, collation checking are needed.
}
if isAllStr && datetimeFound {
argTp = types.ETDatetime
return aggType.EvalType()
}

// unsupportedJSONComparison appends a warning when any argument of the least/greatest function has the JSON type.
func unsupportedJSONComparison(ctx sessionctx.Context, args []Expression) {
for _, arg := range args {
tp := arg.GetType().Tp
if tp == mysql.TypeJSON {
ctx.GetSessionVars().StmtCtx.AppendWarning(errUnsupportedJSONComparison)
break
}
}
return argTp
}

type greatestFunctionClass struct {
Expand All @@ -424,10 +439,14 @@ func (c *greatestFunctionClass) getFunction(ctx sessionctx.Context, args []Expre
if err = c.verifyArgs(args); err != nil {
return nil, err
}
tp, cmpAsDatetime := GetCmpTp4MinMax(args), false
if tp == types.ETDatetime {
tp := resolveType4Extremum(args)
cmpAsDatetime := false
if tp == types.ETDatetime || tp == types.ETTimestamp {
cmpAsDatetime = true
tp = types.ETString
} else if tp == types.ETJson {
unsupportedJSONComparison(ctx, args)
tp = types.ETString
}
argTps := make([]types.EvalType, len(args))
for i := range args {
Expand All @@ -453,7 +472,7 @@ func (c *greatestFunctionClass) getFunction(ctx sessionctx.Context, args []Expre
case types.ETString:
sig = &builtinGreatestStringSig{bf}
sig.setPbCode(tipb.ScalarFuncSig_GreatestString)
case types.ETDatetime:
case types.ETDatetime, types.ETTimestamp:
sig = &builtinGreatestTimeSig{bf}
sig.setPbCode(tipb.ScalarFuncSig_GreatestTime)
}
Expand Down Expand Up @@ -592,30 +611,39 @@ func (b *builtinGreatestTimeSig) Clone() builtinFunc {

// evalString evals a builtinGreatestTimeSig.
// See http://dev.mysql.com/doc/refman/5.7/en/comparison-operators.html#function_greatest
func (b *builtinGreatestTimeSig) evalString(row chunk.Row) (_ string, isNull bool, err error) {
func (b *builtinGreatestTimeSig) evalString(row chunk.Row) (res string, isNull bool, err error) {
var (
v string
t types.Time
strRes string
timeRes types.Time
)
max := types.ZeroDatetime
sc := b.ctx.GetSessionVars().StmtCtx
for i := 0; i < len(b.args); i++ {
v, isNull, err = b.args[i].EvalString(b.ctx, row)
v, isNull, err := b.args[i].EvalString(b.ctx, row)
if isNull || err != nil {
return "", true, err
}
t, err = types.ParseDatetime(sc, v)
t, err := types.ParseDatetime(sc, v)
if err != nil {
if err = handleInvalidTimeError(b.ctx, err); err != nil {
return v, true, err
}
continue
} else {
v = t.String()
}
// In MySQL, if the compare result is zero, then we will try to use the string comparison result.
if i == 0 || strings.Compare(v, strRes) > 0 {
strRes = v
}
if t.Compare(max) > 0 {
max = t
if i == 0 || t.Compare(timeRes) > 0 {
timeRes = t
}
}
return max.String(), false, nil
if timeRes.IsZero() {
res = strRes
} else {
res = timeRes.String()
}
return res, false, nil
}

type leastFunctionClass struct {
Expand All @@ -626,10 +654,14 @@ func (c *leastFunctionClass) getFunction(ctx sessionctx.Context, args []Expressi
if err = c.verifyArgs(args); err != nil {
return nil, err
}
tp, cmpAsDatetime := GetCmpTp4MinMax(args), false
tp := resolveType4Extremum(args)
cmpAsDatetime := false
if tp == types.ETDatetime {
cmpAsDatetime = true
tp = types.ETString
} else if tp == types.ETJson {
unsupportedJSONComparison(ctx, args)
tp = types.ETString
}
argTps := make([]types.EvalType, len(args))
for i := range args {
Expand Down Expand Up @@ -796,32 +828,36 @@ func (b *builtinLeastTimeSig) Clone() builtinFunc {
// See http://dev.mysql.com/doc/refman/5.7/en/comparison-operators.html#function_least
func (b *builtinLeastTimeSig) evalString(row chunk.Row) (res string, isNull bool, err error) {
var (
v string
t types.Time
// timeRes will be converted to a strRes only when the arguments are valid datetime values.
strRes string // Record the strRes of each arguments.
timeRes types.Time // Record the time representation of a valid arguments.
)
min := types.NewTime(types.MaxDatetime, mysql.TypeDatetime, types.MaxFsp)
findInvalidTime := false
sc := b.ctx.GetSessionVars().StmtCtx
for i := 0; i < len(b.args); i++ {
v, isNull, err = b.args[i].EvalString(b.ctx, row)
v, isNull, err := b.args[i].EvalString(b.ctx, row)
if isNull || err != nil {
return "", true, err
}
t, err = types.ParseDatetime(sc, v)
t, err := types.ParseDatetime(sc, v)
if err != nil {
if err = handleInvalidTimeError(b.ctx, err); err != nil {
return v, true, err
} else if !findInvalidTime {
res = v
findInvalidTime = true
}
} else {
v = t.String()
}
if i == 0 || strings.Compare(v, strRes) < 0 {
strRes = v
}
if t.Compare(min) < 0 {
min = t
if i == 0 || t.Compare(timeRes) < 0 {
timeRes = t
}
}
if !findInvalidTime {
res = min.String()

if timeRes.IsZero() {
res = strRes
} else {
res = timeRes.String()
}
return res, false, nil
}
Expand Down Expand Up @@ -1042,7 +1078,7 @@ type compareFunctionClass struct {

// getBaseCmpType gets the EvalType that the two args will be treated as when comparing.
func getBaseCmpType(lhs, rhs types.EvalType, lft, rft *types.FieldType) types.EvalType {
if lft.Tp == mysql.TypeUnspecified || rft.Tp == mysql.TypeUnspecified {
if lft != nil && rft != nil && (lft.Tp == mysql.TypeUnspecified || rft.Tp == mysql.TypeUnspecified) {
if lft.Tp == rft.Tp {
return types.ETString
}
Expand All @@ -1054,13 +1090,13 @@ func getBaseCmpType(lhs, rhs types.EvalType, lft, rft *types.FieldType) types.Ev
}
if lhs.IsStringKind() && rhs.IsStringKind() {
return types.ETString
} else if (lhs == types.ETInt || lft.Hybrid()) && (rhs == types.ETInt || rft.Hybrid()) {
} else if (lhs == types.ETInt || (lft != nil && lft.Hybrid())) && (rhs == types.ETInt || (rft != nil && rft.Hybrid())) {
return types.ETInt
} else if ((lhs == types.ETInt || lft.Hybrid()) || lhs == types.ETDecimal) &&
((rhs == types.ETInt || rft.Hybrid()) || rhs == types.ETDecimal) {
} else if ((lhs == types.ETInt || (lft != nil && lft.Hybrid())) || lhs == types.ETDecimal) &&
((rhs == types.ETInt || (rft != nil && rft.Hybrid())) || rhs == types.ETDecimal) {
return types.ETDecimal
} else if types.IsTemporalWithDate(lft.Tp) && rft.Tp == mysql.TypeYear ||
lft.Tp == mysql.TypeYear && types.IsTemporalWithDate(rft.Tp) {
} else if lft != nil && rft != nil && (types.IsTemporalWithDate(lft.Tp) && rft.Tp == mysql.TypeYear ||
lft.Tp == mysql.TypeYear && types.IsTemporalWithDate(rft.Tp)) {
return types.ETDatetime
}
return types.ETReal
Expand Down
19 changes: 14 additions & 5 deletions expression/builtin_compare_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -258,7 +258,8 @@ func (s *testEvaluatorSuite) TestIntervalFunc(c *C) {
}
}

func (s *testEvaluatorSuite) TestGreatestLeastFuncs(c *C) {
// The greatest/least functions are compatible with MySQL 8.0.
func (s *testEvaluatorSuite) TestGreatestLeastFunc(c *C) {
sc := s.ctx.GetSessionVars().StmtCtx
originIgnoreTruncate := sc.IgnoreTruncate
sc.IgnoreTruncate = true
Expand All @@ -283,23 +284,23 @@ func (s *testEvaluatorSuite) TestGreatestLeastFuncs(c *C) {
},
{
[]interface{}{"123a", "b", "c", 12},
float64(123), float64(0), false, false,
"c", "12", false, false,
},
{
[]interface{}{tm, "123"},
curTimeString, "123", false, false,
},
{
[]interface{}{tm, 123},
curTimeInt, int64(123), false, false,
curTimeString, "123", false, false,
},
{
[]interface{}{tm, "invalid_time_1", "invalid_time_2", tmWithFsp},
curTimeWithFspString, "invalid_time_1", false, false,
curTimeWithFspString, curTimeString, false, false,
},
{
[]interface{}{tm, "invalid_time_2", "invalid_time_1", tmWithFsp},
curTimeWithFspString, "invalid_time_2", false, false,
curTimeWithFspString, curTimeString, false, false,
},
{
[]interface{}{tm, "invalid_time", nil, tmWithFsp},
Expand All @@ -317,6 +318,14 @@ func (s *testEvaluatorSuite) TestGreatestLeastFuncs(c *C) {
[]interface{}{errors.New("must error"), 123},
nil, nil, false, true,
},
{
[]interface{}{794755072.0, 4556, "2000-01-09"},
"794755072", "2000-01-09", false, false,
},
{
[]interface{}{905969664.0, 4556, "1990-06-16 17:22:56.005534"},
"905969664", "1990-06-16 17:22:56.005534", false, false,
},
} {
f0, err := newFunctionForTest(s.ctx, ast.Greatest, s.primitiveValsToConstants(t.args)...)
c.Assert(err, IsNil)
Expand Down
Loading

0 comments on commit dd0dc46

Please sign in to comment.