diff --git a/expression/builtin_compare.go b/expression/builtin_compare.go index 1cce705a4ddc4..b5f45cf866fe2 100644 --- a/expression/builtin_compare.go +++ b/expression/builtin_compare.go @@ -15,6 +15,7 @@ package expression import ( "math" + "strings" "github.com/pingcap/parser/ast" "github.com/pingcap/parser/mysql" @@ -367,53 +368,67 @@ func (b *builtinCoalesceJSONSig) evalJSON(row chunk.Row) (res json.BinaryJSON, i return res, isNull, err } -// temporalWithDateAsNumEvalType makes DATE, DATETIME, TIMESTAMP pretend to be numbers rather than strings. -func temporalWithDateAsNumEvalType(argTp *types.FieldType) (argEvalType types.EvalType, isStr bool, isTemporalWithDate bool) { - argEvalType = argTp.EvalType() - isStr, isTemporalWithDate = argEvalType.IsStringKind(), types.IsTemporalWithDate(argTp.Tp) - if !isTemporalWithDate { - return +func aggregateType(args []Expression) *types.FieldType { + fieldTypes := make([]*types.FieldType, len(args)) + for i := range fieldTypes { + fieldTypes[i] = args[i].GetType() } - if argTp.Decimal > 0 { - argEvalType = types.ETDecimal - } else { - argEvalType = types.ETInt - } - return + return types.AggFieldType(fieldTypes) } -// GetCmpTp4MinMax gets compare type for GREATEST and LEAST and BETWEEN -func GetCmpTp4MinMax(args []Expression) (argTp types.EvalType) { - datetimeFound, isAllStr := false, true - cmpEvalType, isStr, isTemporalWithDate := temporalWithDateAsNumEvalType(args[0].GetType()) - if !isStr { - isAllStr = false +// ResolveType4Between resolves eval type for between expression. +func ResolveType4Between(args [3]Expression) types.EvalType { + cmpTp := args[0].GetType().EvalType() + for i := 1; i < 3; i++ { + cmpTp = getBaseCmpType(cmpTp, args[i].GetType().EvalType(), nil, nil) } - if isTemporalWithDate { - datetimeFound = true - } - lft := args[0].GetType() - for i := range args { - rft := args[i].GetType() - var tp types.EvalType - tp, isStr, isTemporalWithDate = temporalWithDateAsNumEvalType(rft) - if isTemporalWithDate { - datetimeFound = true + + hasTemporal := false + if cmpTp == types.ETString { + for _, arg := range args { + if types.IsTypeTemporal(arg.GetType().Tp) { + hasTemporal = true + break + } } - if !isStr { - isAllStr = false + if hasTemporal { + cmpTp = types.ETDatetime } - cmpEvalType = getBaseCmpType(cmpEvalType, tp, lft, rft) - lft = rft } - argTp = cmpEvalType - if cmpEvalType.IsStringKind() { - argTp = types.ETString + + return cmpTp +} + +// resolveType4Extremum gets compare type for GREATEST and LEAST and BETWEEN (mainly for datetime). +func resolveType4Extremum(args []Expression) types.EvalType { + aggType := aggregateType(args) + + var temporalItem *types.FieldType + if aggType.EvalType().IsStringKind() { + for i := range args { + item := args[i].GetType() + if types.IsTemporalWithDate(item.Tp) { + temporalItem = item + } + } + + if !types.IsTemporalWithDate(aggType.Tp) && temporalItem != nil { + aggType.Tp = temporalItem.Tp + } + // TODO: String charset, collation checking are needed. } - if isAllStr && datetimeFound { - argTp = types.ETDatetime + return aggType.EvalType() +} + +// unsupportedJSONComparison reports warnings while there is a JSON type in least/greatest function's arguments +func unsupportedJSONComparison(ctx sessionctx.Context, args []Expression) { + for _, arg := range args { + tp := arg.GetType().Tp + if tp == mysql.TypeJSON { + ctx.GetSessionVars().StmtCtx.AppendWarning(errUnsupportedJSONComparison) + break + } } - return argTp } type greatestFunctionClass struct { @@ -424,10 +439,14 @@ func (c *greatestFunctionClass) getFunction(ctx sessionctx.Context, args []Expre if err = c.verifyArgs(args); err != nil { return nil, err } - tp, cmpAsDatetime := GetCmpTp4MinMax(args), false - if tp == types.ETDatetime { + tp := resolveType4Extremum(args) + cmpAsDatetime := false + if tp == types.ETDatetime || tp == types.ETTimestamp { cmpAsDatetime = true tp = types.ETString + } else if tp == types.ETJson { + unsupportedJSONComparison(ctx, args) + tp = types.ETString } argTps := make([]types.EvalType, len(args)) for i := range args { @@ -453,7 +472,7 @@ func (c *greatestFunctionClass) getFunction(ctx sessionctx.Context, args []Expre case types.ETString: sig = &builtinGreatestStringSig{bf} sig.setPbCode(tipb.ScalarFuncSig_GreatestString) - case types.ETDatetime: + case types.ETDatetime, types.ETTimestamp: sig = &builtinGreatestTimeSig{bf} sig.setPbCode(tipb.ScalarFuncSig_GreatestTime) } @@ -592,30 +611,39 @@ func (b *builtinGreatestTimeSig) Clone() builtinFunc { // evalString evals a builtinGreatestTimeSig. // See http://dev.mysql.com/doc/refman/5.7/en/comparison-operators.html#function_greatest -func (b *builtinGreatestTimeSig) evalString(row chunk.Row) (_ string, isNull bool, err error) { +func (b *builtinGreatestTimeSig) evalString(row chunk.Row) (res string, isNull bool, err error) { var ( - v string - t types.Time + strRes string + timeRes types.Time ) - max := types.ZeroDatetime sc := b.ctx.GetSessionVars().StmtCtx for i := 0; i < len(b.args); i++ { - v, isNull, err = b.args[i].EvalString(b.ctx, row) + v, isNull, err := b.args[i].EvalString(b.ctx, row) if isNull || err != nil { return "", true, err } - t, err = types.ParseDatetime(sc, v) + t, err := types.ParseDatetime(sc, v) if err != nil { if err = handleInvalidTimeError(b.ctx, err); err != nil { return v, true, err } - continue + } else { + v = t.String() + } + // In MySQL, if the compare result is zero, than we will try to use the string comparison result + if i == 0 || strings.Compare(v, strRes) > 0 { + strRes = v } - if t.Compare(max) > 0 { - max = t + if i == 0 || t.Compare(timeRes) > 0 { + timeRes = t } } - return max.String(), false, nil + if timeRes.IsZero() { + res = strRes + } else { + res = timeRes.String() + } + return res, false, nil } type leastFunctionClass struct { @@ -626,10 +654,14 @@ func (c *leastFunctionClass) getFunction(ctx sessionctx.Context, args []Expressi if err = c.verifyArgs(args); err != nil { return nil, err } - tp, cmpAsDatetime := GetCmpTp4MinMax(args), false + tp := resolveType4Extremum(args) + cmpAsDatetime := false if tp == types.ETDatetime { cmpAsDatetime = true tp = types.ETString + } else if tp == types.ETJson { + unsupportedJSONComparison(ctx, args) + tp = types.ETString } argTps := make([]types.EvalType, len(args)) for i := range args { @@ -796,32 +828,36 @@ func (b *builtinLeastTimeSig) Clone() builtinFunc { // See http://dev.mysql.com/doc/refman/5.7/en/comparison-operators.html#functionleast func (b *builtinLeastTimeSig) evalString(row chunk.Row) (res string, isNull bool, err error) { var ( - v string - t types.Time + // timeRes will be converted to a strRes only when the arguments is a valid datetime value. + strRes string // Record the strRes of each arguments. + timeRes types.Time // Record the time representation of a valid arguments. ) - min := types.NewTime(types.MaxDatetime, mysql.TypeDatetime, types.MaxFsp) - findInvalidTime := false sc := b.ctx.GetSessionVars().StmtCtx for i := 0; i < len(b.args); i++ { - v, isNull, err = b.args[i].EvalString(b.ctx, row) + v, isNull, err := b.args[i].EvalString(b.ctx, row) if isNull || err != nil { return "", true, err } - t, err = types.ParseDatetime(sc, v) + t, err := types.ParseDatetime(sc, v) if err != nil { if err = handleInvalidTimeError(b.ctx, err); err != nil { return v, true, err - } else if !findInvalidTime { - res = v - findInvalidTime = true } + } else { + v = t.String() + } + if i == 0 || strings.Compare(v, strRes) < 0 { + strRes = v } - if t.Compare(min) < 0 { - min = t + if i == 0 || t.Compare(timeRes) < 0 { + timeRes = t } } - if !findInvalidTime { - res = min.String() + + if timeRes.IsZero() { + res = strRes + } else { + res = timeRes.String() } return res, false, nil } @@ -1042,7 +1078,7 @@ type compareFunctionClass struct { // getBaseCmpType gets the EvalType that the two args will be treated as when comparing. func getBaseCmpType(lhs, rhs types.EvalType, lft, rft *types.FieldType) types.EvalType { - if lft.Tp == mysql.TypeUnspecified || rft.Tp == mysql.TypeUnspecified { + if lft != nil && rft != nil && (lft.Tp == mysql.TypeUnspecified || rft.Tp == mysql.TypeUnspecified) { if lft.Tp == rft.Tp { return types.ETString } @@ -1054,11 +1090,14 @@ func getBaseCmpType(lhs, rhs types.EvalType, lft, rft *types.FieldType) types.Ev } if lhs.IsStringKind() && rhs.IsStringKind() { return types.ETString - } else if (lhs == types.ETInt || lft.Hybrid()) && (rhs == types.ETInt || rft.Hybrid()) { + } else if (lhs == types.ETInt || (lft != nil && lft.Hybrid())) && (rhs == types.ETInt || (rft != nil && rft.Hybrid())) { return types.ETInt - } else if ((lhs == types.ETInt || lft.Hybrid()) || lhs == types.ETDecimal) && - ((rhs == types.ETInt || rft.Hybrid()) || rhs == types.ETDecimal) { + } else if ((lhs == types.ETInt || (lft != nil && lft.Hybrid())) || lhs == types.ETDecimal) && + ((rhs == types.ETInt || (rft != nil && rft.Hybrid())) || rhs == types.ETDecimal) { return types.ETDecimal + } else if lft != nil && rft != nil && (types.IsTemporalWithDate(lft.Tp) && rft.Tp == mysql.TypeYear || + lft.Tp == mysql.TypeYear && types.IsTemporalWithDate(rft.Tp)) { + return types.ETDatetime } return types.ETReal } diff --git a/expression/builtin_compare_test.go b/expression/builtin_compare_test.go index cc1703056917f..bf2ac4cdb65fd 100644 --- a/expression/builtin_compare_test.go +++ b/expression/builtin_compare_test.go @@ -257,7 +257,8 @@ func (s *testEvaluatorSuite) TestIntervalFunc(c *C) { } } -func (s *testEvaluatorSuite) TestGreatestLeastFuncs(c *C) { +// greatest/least function is compatible with MySQL 8.0 +func (s *testEvaluatorSuite) TestGreatestLeastFunc(c *C) { sc := s.ctx.GetSessionVars().StmtCtx originIgnoreTruncate := sc.IgnoreTruncate sc.IgnoreTruncate = true @@ -282,7 +283,7 @@ func (s *testEvaluatorSuite) TestGreatestLeastFuncs(c *C) { }, { []interface{}{"123a", "b", "c", 12}, - float64(123), float64(0), false, false, + "c", "12", false, false, }, { []interface{}{tm, "123"}, @@ -290,15 +291,15 @@ func (s *testEvaluatorSuite) TestGreatestLeastFuncs(c *C) { }, { []interface{}{tm, 123}, - curTimeInt, int64(123), false, false, + curTimeString, "123", false, false, }, { []interface{}{tm, "invalid_time_1", "invalid_time_2", tmWithFsp}, - curTimeWithFspString, "invalid_time_1", false, false, + curTimeWithFspString, curTimeString, false, false, }, { []interface{}{tm, "invalid_time_2", "invalid_time_1", tmWithFsp}, - curTimeWithFspString, "invalid_time_2", false, false, + curTimeWithFspString, curTimeString, false, false, }, { []interface{}{tm, "invalid_time", nil, tmWithFsp}, @@ -316,6 +317,14 @@ func (s *testEvaluatorSuite) TestGreatestLeastFuncs(c *C) { []interface{}{errors.New("must error"), 123}, nil, nil, false, true, }, + { + []interface{}{794755072.0, 4556, "2000-01-09"}, + "794755072", "2000-01-09", false, false, + }, + { + []interface{}{905969664.0, 4556, "1990-06-16 17:22:56.005534"}, + "905969664", "1990-06-16 17:22:56.005534", false, false, + }, } { f0, err := newFunctionForTest(s.ctx, ast.Greatest, s.primitiveValsToConstants(t.args)...) c.Assert(err, IsNil) diff --git a/expression/builtin_compare_vec.go b/expression/builtin_compare_vec.go index b55d676ae2b0c..09343305faea1 100644 --- a/expression/builtin_compare_vec.go +++ b/expression/builtin_compare_vec.go @@ -14,6 +14,8 @@ package expression import ( + "strings" + "github.com/pingcap/parser/mysql" "github.com/pingcap/tidb/types" "github.com/pingcap/tidb/util/chunk" @@ -633,47 +635,46 @@ func (b *builtinGreatestTimeSig) vectorized() bool { } func (b *builtinGreatestTimeSig) vecEvalString(input *chunk.Chunk, result *chunk.Column) error { + sc := b.ctx.GetSessionVars().StmtCtx n := input.NumRows() - dst, err := b.bufAllocator.get(types.ETTimestamp, n) - if err != nil { - return err - } - defer b.bufAllocator.put(dst) - sc := b.ctx.GetSessionVars().StmtCtx - dst.ResizeTime(n, false) - dstTimes := dst.Times() - for i := 0; i < n; i++ { - dstTimes[i] = types.ZeroDatetime - } - var argTime types.Time + dstStrings := make([]string, n) + // TODO: use Column.MergeNulls instead, however, it doesn't support var-length type currently. + dstNullMap := make([]bool, n) + for j := 0; j < len(b.args); j++ { if err := b.args[j].VecEvalString(b.ctx, input, result); err != nil { return err } for i := 0; i < n; i++ { - if result.IsNull(i) || dst.IsNull(i) { - dst.SetNull(i, true) + if dstNullMap[i] = dstNullMap[i] || result.IsNull(i); dstNullMap[i] { continue } - argTime, err = types.ParseDatetime(sc, result.GetString(i)) + + // NOTE: can't use Column.GetString because it returns an unsafe string, copy the row instead. + argTimeStr := string(result.GetBytes(i)) + + argTime, err := types.ParseDatetime(sc, argTimeStr) if err != nil { if err = handleInvalidTimeError(b.ctx, err); err != nil { return err } - continue + } else { + argTimeStr = argTime.String() } - if argTime.Compare(dstTimes[i]) > 0 { - dstTimes[i] = argTime + if j == 0 || strings.Compare(argTimeStr, dstStrings[i]) > 0 { + dstStrings[i] = argTimeStr } } } + + // Aggregate the NULL and String value into result result.ReserveString(n) for i := 0; i < n; i++ { - if dst.IsNull(i) { + if dstNullMap[i] { result.AppendNull() } else { - result.AppendString(dstTimes[i].String()) + result.AppendString(dstStrings[i]) } } return nil @@ -719,60 +720,46 @@ func (b *builtinLeastTimeSig) vectorized() bool { } func (b *builtinLeastTimeSig) vecEvalString(input *chunk.Chunk, result *chunk.Column) error { - n := input.NumRows() - dst, err := b.bufAllocator.get(types.ETTimestamp, n) - if err != nil { - return err - } - defer b.bufAllocator.put(dst) - sc := b.ctx.GetSessionVars().StmtCtx - dst.ResizeTime(n, false) - dstTimes := dst.Times() - for i := 0; i < n; i++ { - dstTimes[i] = types.NewTime(types.MaxDatetime, mysql.TypeDatetime, types.DefaultFsp) - } - var argTime types.Time + n := input.NumRows() - var findInvalidTime []bool = make([]bool, n) - var invalidValue []string = make([]string, n) + dstStrings := make([]string, n) + // TODO: use Column.MergeNulls instead, however, it doesn't support var-length type currently. + dstNullMap := make([]bool, n) for j := 0; j < len(b.args); j++ { if err := b.args[j].VecEvalString(b.ctx, input, result); err != nil { return err } - dst.MergeNulls(result) for i := 0; i < n; i++ { - if dst.IsNull(i) { + if dstNullMap[i] = dstNullMap[i] || result.IsNull(i); dstNullMap[i] { continue } - argTime, err = types.ParseDatetime(sc, result.GetString(i)) + + // NOTE: can't use Column.GetString because it returns an unsafe string, copy the row instead. + argTimeStr := string(result.GetBytes(i)) + + argTime, err := types.ParseDatetime(sc, argTimeStr) if err != nil { if err = handleInvalidTimeError(b.ctx, err); err != nil { return err - } else if !findInvalidTime[i] { - // Make a deep copy here. - // Otherwise invalidValue will internally change with result. - invalidValue[i] = string(result.GetBytes(i)) - findInvalidTime[i] = true } - continue + } else { + argTimeStr = argTime.String() } - if argTime.Compare(dstTimes[i]) < 0 { - dstTimes[i] = argTime + if j == 0 || strings.Compare(argTimeStr, dstStrings[i]) < 0 { + dstStrings[i] = argTimeStr } } } + + // Aggregate the NULL and String value into result result.ReserveString(n) for i := 0; i < n; i++ { - if dst.IsNull(i) { + if dstNullMap[i] { result.AppendNull() - continue - } - if findInvalidTime[i] { - result.AppendString(invalidValue[i]) } else { - result.AppendString(dstTimes[i].String()) + result.AppendString(dstStrings[i]) } } return nil diff --git a/expression/errors.go b/expression/errors.go index e3ae09e93fe0a..10b524f19c8c1 100644 --- a/expression/errors.go +++ b/expression/errors.go @@ -49,7 +49,9 @@ var ( errNonUniq = dbterror.ClassExpression.NewStd(mysql.ErrNonUniq) // Sequence usage privilege check. - errSequenceAccessDenied = dbterror.ClassExpression.NewStd(mysql.ErrTableaccessDenied) + errSequenceAccessDenied = dbterror.ClassExpression.NewStd(mysql.ErrTableaccessDenied) + errUnsupportedJSONComparison = dbterror.ClassExpression.NewStdErr(mysql.ErrNotSupportedYet, + pmysql.Message("comparison of JSON in the LEAST and GREATEST operators", nil)) ) // handleInvalidTimeError reports error or warning depend on the context. diff --git a/expression/integration_test.go b/expression/integration_test.go index 5d98713da7699..55aff0acdb7a9 100755 --- a/expression/integration_test.go +++ b/expression/integration_test.go @@ -3656,8 +3656,8 @@ func (s *testIntegrationSuite) TestCompareBuiltin(c *C) { // for greatest result = tk.MustQuery(`select greatest(1, 2, 3), greatest("a", "b", "c"), greatest(1.1, 1.2, 1.3), greatest("123a", 1, 2)`) - result.Check(testkit.Rows("3 c 1.3 123")) - tk.MustQuery("show warnings").Check(testutil.RowsWithSep("|", "Warning|1292|Truncated incorrect FLOAT value: '123a'")) + result.Check(testkit.Rows("3 c 1.3 2")) + tk.MustQuery("show warnings").Check(testkit.Rows()) result = tk.MustQuery(`select greatest(cast("2017-01-01" as datetime), "123", "234", cast("2018-01-01" as date)), greatest(cast("2017-01-01" as date), "123", null)`) // todo: MySQL returns "2018-01-01 " result.Check(testkit.Rows("2018-01-01 00:00:00 ")) @@ -3665,7 +3665,7 @@ func (s *testIntegrationSuite) TestCompareBuiltin(c *C) { // for least result = tk.MustQuery(`select least(1, 2, 3), least("a", "b", "c"), least(1.1, 1.2, 1.3), least("123a", 1, 2)`) result.Check(testkit.Rows("1 a 1.1 1")) - tk.MustQuery("show warnings").Check(testutil.RowsWithSep("|", "Warning|1292|Truncated incorrect FLOAT value: '123a'")) + tk.MustQuery("show warnings").Check(testkit.Rows()) result = tk.MustQuery(`select least(cast("2017-01-01" as datetime), "123", "234", cast("2018-01-01" as date)), least(cast("2017-01-01" as date), "123", null)`) result.Check(testkit.Rows("123 ")) tk.MustQuery("show warnings").Check(testutil.RowsWithSep("|", "Warning|1292|Incorrect time value: '123'", "Warning|1292|Incorrect time value: '234'", "Warning|1292|Incorrect time value: '123'")) diff --git a/planner/core/expression_rewriter.go b/planner/core/expression_rewriter.go index 4b548211a6c6b..f66e9b614e654 100644 --- a/planner/core/expression_rewriter.go +++ b/planner/core/expression_rewriter.go @@ -1512,7 +1512,7 @@ func (er *expressionRewriter) wrapExpWithCast() (expr, lexp, rexp expression.Exp stkLen := len(er.ctxStack) expr, lexp, rexp = er.ctxStack[stkLen-3], er.ctxStack[stkLen-2], er.ctxStack[stkLen-1] var castFunc func(sessionctx.Context, expression.Expression) expression.Expression - switch expression.GetCmpTp4MinMax([]expression.Expression{expr, lexp, rexp}) { + switch expression.ResolveType4Between([3]expression.Expression{expr, lexp, rexp}) { case types.ETInt: castFunc = expression.WrapWithCastAsInt case types.ETReal: diff --git a/planner/core/expression_test.go b/planner/core/expression_test.go index 0f19b800a52ed..4461db382b47a 100644 --- a/planner/core/expression_test.go +++ b/planner/core/expression_test.go @@ -72,6 +72,7 @@ func (s *testExpressionSuite) TestBetween(c *C) { {exprStr: "1 not between 2 and 3", resultStr: "1"}, {exprStr: "'2001-04-10 12:34:56' between cast('2001-01-01 01:01:01' as datetime) and '01-05-01'", resultStr: "1"}, {exprStr: "20010410123456 between cast('2001-01-01 01:01:01' as datetime) and 010501", resultStr: "0"}, + {exprStr: "20010410123456 between cast('2001-01-01 01:01:01' as datetime) and 20010501123456", resultStr: "1"}, } s.runTests(c, tests) } diff --git a/types/field_type.go b/types/field_type.go index cb7d43ebbaaaf..b93147564c5fe 100644 --- a/types/field_type.go +++ b/types/field_type.go @@ -98,6 +98,10 @@ func AggFieldType(tps []*FieldType) *FieldType { } } + if mysql.HasUnsignedFlag(currType.Flag) && !isMixedSign { + currType.Flag |= mysql.UnsignedFlag + } + return &currType }