diff --git a/errors.toml b/errors.toml index 22b41340e989f..f9d455c7febd4 100644 --- a/errors.toml +++ b/errors.toml @@ -496,6 +496,16 @@ error = ''' Illegal mix of collations (%s,%s) and (%s,%s) for operation '%s' ''' +["expression:1270"] +error = ''' +Illegal mix of collations (%s,%s), (%s,%s), (%s,%s) for operation '%s' +''' + +["expression:1271"] +error = ''' +Illegal mix of collations for operation '%s' +''' + ["expression:1365"] error = ''' Division by 0 diff --git a/expression/builtin.go b/expression/builtin.go index a2cb13c045b98..aaf5b95a317a6 100644 --- a/expression/builtin.go +++ b/expression/builtin.go @@ -109,23 +109,48 @@ func newBaseBuiltinFunc(ctx sessionctx.Context, funcName string, args []Expressi return bf, nil } +var ( + // allowDeriveNoneFunction lists the functions that are still allowed when their arguments derive to CoercibilityNone, i.e. when two incompatible collations which have the same charset are mixed + allowDeriveNoneFunction = map[string]struct{}{ + ast.Concat: {}, ast.ConcatWS: {}, ast.Reverse: {}, ast.Replace: {}, ast.InsertFunc: {}, ast.Lower: {}, + ast.Upper: {}, ast.Left: {}, ast.Right: {}, ast.Substr: {}, ast.SubstringIndex: {}, ast.Trim: {}, + ast.CurrentUser: {}, ast.Elt: {}, ast.MakeSet: {}, ast.Repeat: {}, ast.Rpad: {}, ast.Lpad: {}, + ast.ExportSet: {}, + } + + // coerString is indexed by a Coercibility value and gives the name printed in illegal-mix error messages + coerString = []string{"EXPLICIT", "NONE", "IMPLICIT", "SYSCONST", "COERCIBLE", "NUMERIC", "IGNORABLE"} +) + +// checkIllegalMixCollation returns an error when inferCollation reports the mix as illegal, or when the arguments derive to CoercibilityNone and funcName is not in allowDeriveNoneFunction. func checkIllegalMixCollation(funcName string, args []Expression) error { - firstExplicitCollation := "" - for _, arg := range args { - if arg.GetType().EvalType() != types.ETString { - continue - } - if arg.Coercibility() == CoercibilityExplicit { - if firstExplicitCollation == "" { - firstExplicitCollation = arg.GetType().Collate - } else if firstExplicitCollation != arg.GetType().Collate { - return collate.ErrIllegalMixCollation.GenWithStackByArgs(firstExplicitCollation, "EXPLICIT", arg.GetType().Collate, "EXPLICIT", funcName) - } + if len(args) < 2 { + return nil + } + _, _, coercibility, legal := 
inferCollation(args...) + if !legal { + return illegalMixCollationErr(funcName, args) + } + if coercibility == CoercibilityNone { + if _, ok := allowDeriveNoneFunction[funcName]; !ok { + return illegalMixCollationErr(funcName, args) } } return nil } +// illegalMixCollationErr builds the illegal-mix error for funcName: with 2 or 3 args it reports each argument's collation and coercibility, otherwise only the operation name. +func illegalMixCollationErr(funcName string, args []Expression) error { + switch len(args) { + case 2: + return collate.ErrIllegalMix2Collation.GenWithStackByArgs(args[0].GetType().Collate, coerString[args[0].Coercibility()], args[1].GetType().Collate, coerString[args[1].Coercibility()], funcName) + case 3: + return collate.ErrIllegalMix3Collation.GenWithStackByArgs(args[0].GetType().Collate, coerString[args[0].Coercibility()], args[1].GetType().Collate, coerString[args[1].Coercibility()], args[2].GetType().Collate, coerString[args[2].Coercibility()], funcName) + default: + return collate.ErrIllegalMixCollation.GenWithStackByArgs(funcName) + } +} + // newBaseBuiltinFuncWithTp creates a built-in function signature with specified types of arguments and the return type of the function. // argTps indicates the types of the args, retType indicates the return type of the built-in function. // Every built-in function needs determined argTps and retType when we create it. 
diff --git a/expression/collation.go b/expression/collation.go index 492f7788a6941..e2acf5d6e86ef 100644 --- a/expression/collation.go +++ b/expression/collation.go @@ -14,12 +14,11 @@ package expression import ( - "strings" - "github.com/pingcap/parser/ast" "github.com/pingcap/parser/charset" "github.com/pingcap/tidb/sessionctx" "github.com/pingcap/tidb/types" + "github.com/pingcap/tidb/util/logutil" ) type collationInfo struct { @@ -113,14 +112,17 @@ var ( } // collationPriority is the priority when infer the result collation, the priority of collation a > b iff collationPriority[a] > collationPriority[b] + // collations a and b are incompatible if collationPriority[a] == collationPriority[b] collationPriority = map[string]int{ - charset.CollationASCII: 0, - charset.CollationLatin1: 1, - "utf8_general_ci": 2, - charset.CollationUTF8: 3, - "utf8mb4_general_ci": 4, - charset.CollationUTF8MB4: 5, - charset.CollationBin: 6, + charset.CollationASCII: 1, + charset.CollationLatin1: 2, + "utf8_general_ci": 3, + "utf8_unicode_ci": 3, + charset.CollationUTF8: 4, + "utf8mb4_general_ci": 5, + "utf8mb4_unicode_ci": 5, + charset.CollationUTF8MB4: 6, + charset.CollationBin: 7, } // CollationStrictnessGroup group collation by strictness @@ -151,18 +153,13 @@ func deriveCoercibilityForScarlarFunc(sf *ScalarFunction) Coercibility { if sf.RetType.EvalType() != types.ETString { return CoercibilityNumeric } - coer := CoercibilityIgnorable - for _, arg := range sf.GetArgs() { - if arg.Coercibility() < coer { - coer = arg.Coercibility() - } - } + + _, _, coer, _ := inferCollation(sf.GetArgs()...) // it is weird if a ScalarFunction is CoercibilityNumeric but return string type if coer == CoercibilityNumeric { return CoercibilityCoercible } - return coer } @@ -184,42 +181,51 @@ func deriveCoercibilityForColumn(c *Column) Coercibility { // DeriveCollationFromExprs derives collation information from these expressions. 
func DeriveCollationFromExprs(ctx sessionctx.Context, exprs ...Expression) (dstCharset, dstCollation string) { - curCoer := CoercibilityIgnorable - curCollationPriority := -1 - dstCharset, dstCollation = charset.GetDefaultCharsetAndCollate() - if ctx != nil && ctx.GetSessionVars() != nil { - dstCharset, dstCollation = ctx.GetSessionVars().GetCharsetInfo() - if dstCharset == "" || dstCollation == "" { - dstCharset, dstCollation = charset.GetDefaultCharsetAndCollate() - } - } - hasStrArg := false - // see https://dev.mysql.com/doc/refman/8.0/en/charset-collation-coercibility.html - for _, e := range exprs { - if e.GetType().EvalType() != types.ETString { - continue - } - hasStrArg = true + dstCollation, dstCharset, _, _ = inferCollation(exprs...) + return +} - coer := e.Coercibility() - ft := e.GetType() - collationPriority, ok := collationPriority[strings.ToLower(ft.Collate)] - if !ok { - collationPriority = -1 - } - if coer != curCoer { - if coer < curCoer { - curCoer, curCollationPriority, dstCharset, dstCollation = coer, collationPriority, ft.Charset, ft.Collate +// inferCollation infers the result collation, charset and coercibility of exprs; legal is false only when two arguments carry different explicit collations. 
+func inferCollation(exprs ...Expression) (dstCollation, dstCharset string, coercibility Coercibility, legal bool) { + firstExplicitCollation := "" + coercibility = CoercibilityIgnorable + dstCharset, dstCollation = charset.GetDefaultCharsetAndCollate() + for _, arg := range exprs { + if arg.Coercibility() == CoercibilityExplicit { + if firstExplicitCollation == "" { + firstExplicitCollation = arg.GetType().Collate + coercibility, dstCollation, dstCharset = CoercibilityExplicit, arg.GetType().Collate, arg.GetType().Charset + } else if firstExplicitCollation != arg.GetType().Collate { + return "", "", CoercibilityIgnorable, false + } + } else if arg.Coercibility() < coercibility { + coercibility, dstCollation, dstCharset = arg.Coercibility(), arg.GetType().Collate, arg.GetType().Charset + } else if arg.Coercibility() == coercibility && dstCollation != arg.GetType().Collate { + p1 := collationPriority[dstCollation] + p2 := collationPriority[arg.GetType().Collate] + + // the same priority means these two collations are incompatible, so the coercibility derives to CoercibilityNone + if p1 == p2 { + coercibility, dstCollation, dstCharset = CoercibilityNone, getBinCollation(arg.GetType().Charset), arg.GetType().Charset + } else if p1 < p2 { + dstCollation, dstCharset = arg.GetType().Collate, arg.GetType().Charset } - continue - } - if !ok || collationPriority <= curCollationPriority { - continue } - curCollationPriority, dstCharset, dstCollation = collationPriority, ft.Charset, ft.Collate } - if !hasStrArg { - dstCharset, dstCollation = charset.CharsetBin, charset.CollationBin + + return dstCollation, dstCharset, coercibility, true +} + +// getBinCollation gets the binary collation of the given charset; only utf8 and utf8mb4 are handled explicitly +func getBinCollation(cs string) string { + switch cs { + case charset.CharsetUTF8: + return charset.CollationUTF8 + case charset.CharsetUTF8MB4: + return charset.CollationUTF8MB4 } - return + + logutil.BgLogger().Error("unexpected charset " + cs) + // a value must be returned here; this point is not expected to be reached + 
return charset.CollationUTF8MB4 } diff --git a/expression/collation_test.go b/expression/collation_test.go index c18f7d882e779..698528baadce4 100644 --- a/expression/collation_test.go +++ b/expression/collation_test.go @@ -79,6 +79,7 @@ func (s *testCollationSuites) TestCompareString(c *C) { func (s *testCollationSuites) TestDeriveCollationFromExprs(c *C) { tInt := types.NewFieldType(mysql.TypeLonglong) + tInt.Charset = charset.CharsetBin ctx := mock.NewContext() // no string column diff --git a/expression/integration_test.go b/expression/integration_test.go index a8d6fb3d1e0a7..1cf2b6d72f931 100755 --- a/expression/integration_test.go +++ b/expression/integration_test.go @@ -6225,6 +6225,85 @@ func (s *testIntegrationSerialSuite) TestCollateConstantPropagation(c *C) { tk.MustExec("insert into t values ('a', 'a');") tk.MustQuery("select * from t t1, t t2 where t2.b = 'A' and lower(concat(t1.a , '' )) = t2.b;").Check(testkit.Rows("a a a a")) } + +func (s *testIntegrationSerialSuite) TestMixCollation(c *C) { + tk := testkit.NewTestKit(c, s.store) + collate.SetNewCollationEnabledForTest(true) + defer collate.SetNewCollationEnabledForTest(false) + + tk.MustGetErrMsg(`select 'a' collate utf8mb4_bin = 'a' collate utf8mb4_general_ci;`, "[expression:1267]Illegal mix of collations (utf8mb4_bin,EXPLICIT) and (utf8mb4_general_ci,EXPLICIT) for operation 'eq'") + + tk.MustExec("use test;") + tk.MustExec("drop table if exists t;") + tk.MustExec(`create table t ( + mb4general varchar(10) charset utf8mb4 collate utf8mb4_general_ci, + mb4unicode varchar(10) charset utf8mb4 collate utf8mb4_unicode_ci, + mb4bin varchar(10) charset utf8mb4 collate utf8mb4_bin, + general varchar(10) charset utf8 collate utf8_general_ci, + unicode varchar(10) charset utf8 collate utf8_unicode_ci, + utfbin varchar(10) charset utf8 collate utf8_bin, + bin varchar(10) charset binary collate binary, + latin1_bin varchar(10) charset latin1 collate latin1_bin, + ascii_bin varchar(10) charset ascii collate 
ascii_bin, + i int + );`) + tk.MustExec("insert into t values ('s', 's', 's', 's', 's', 's', 's', 's', 's', 1);") + tk.MustExec("set names utf8mb4 collate utf8mb4_general_ci;") + + tk.MustQuery("select * from t where mb4unicode = 's' collate utf8mb4_unicode_ci;").Check(testkit.Rows("s s s s s s s s s 1")) + tk.MustQuery(`select * from t t1, t t2 where t1.mb4unicode = t2.mb4general collate utf8mb4_general_ci;`).Check(testkit.Rows("s s s s s s s s s 1 s s s s s s s s s 1")) + tk.MustQuery(`select * from t t1, t t2 where t1.mb4general = t2.mb4unicode collate utf8mb4_general_ci;`).Check(testkit.Rows("s s s s s s s s s 1 s s s s s s s s s 1")) + tk.MustQuery(`select * from t t1, t t2 where t1.mb4general = t2.mb4unicode collate utf8mb4_unicode_ci;`).Check(testkit.Rows("s s s s s s s s s 1 s s s s s s s s s 1")) + tk.MustQuery(`select * from t t1, t t2 where t1.mb4unicode = t2.mb4general collate utf8mb4_unicode_ci;`).Check(testkit.Rows("s s s s s s s s s 1 s s s s s s s s s 1")) + tk.MustQuery(`select * from t where mb4general = mb4bin collate utf8mb4_general_ci;`).Check(testkit.Rows("s s s s s s s s s 1")) + tk.MustQuery(`select * from t where mb4unicode = mb4general collate utf8mb4_unicode_ci;`).Check(testkit.Rows("s s s s s s s s s 1")) + tk.MustQuery(`select * from t where mb4general = mb4unicode collate utf8mb4_unicode_ci;`).Check(testkit.Rows("s s s s s s s s s 1")) + tk.MustQuery(`select * from t where mb4unicode = 's' collate utf8mb4_unicode_ci;`).Check(testkit.Rows("s s s s s s s s s 1")) + tk.MustQuery("select * from t where mb4unicode = mb4bin;").Check(testkit.Rows("s s s s s s s s s 1")) + tk.MustQuery("select * from t where general = mb4unicode;").Check(testkit.Rows("s s s s s s s s s 1")) + tk.MustQuery("select * from t where unicode = mb4unicode;").Check(testkit.Rows("s s s s s s s s s 1")) + tk.MustQuery("select * from t where mb4unicode = mb4unicode;").Check(testkit.Rows("s s s s s s s s s 1")) + + tk.MustQuery("select collation(concat(mb4unicode, 
mb4general collate utf8mb4_unicode_ci)) from t;").Check(testkit.Rows("utf8mb4_unicode_ci")) + tk.MustQuery("select collation(concat(mb4general, mb4unicode, mb4bin)) from t;").Check(testkit.Rows("utf8mb4_bin")) + tk.MustQuery("select coercibility(concat(mb4general, mb4unicode, mb4bin)) from t;").Check(testkit.Rows("1")) + tk.MustQuery("select collation(concat(mb4unicode, mb4bin, concat(mb4general))) from t;").Check(testkit.Rows("utf8mb4_bin")) + tk.MustQuery("select coercibility(concat(mb4unicode, mb4bin)) from t;").Check(testkit.Rows("2")) + tk.MustQuery("select collation(concat(mb4unicode, mb4bin)) from t;").Check(testkit.Rows("utf8mb4_bin")) + tk.MustQuery("select coercibility(concat(mb4bin, concat(mb4general))) from t;").Check(testkit.Rows("2")) + tk.MustQuery("select collation(concaT(mb4bin, cOncAt(mb4general))) from t;").Check(testkit.Rows("utf8mb4_bin")) + tk.MustQuery("select coercibility(concat(mb4unicode, mb4bin, concat(mb4general))) from t;").Check(testkit.Rows("2")) + tk.MustQuery("select collation(concat(mb4unicode, mb4bin, concat(mb4general))) from t;").Check(testkit.Rows("utf8mb4_bin")) + tk.MustQuery("select coercibility(concat(mb4unicode, mb4general)) from t;").Check(testkit.Rows("1")) + tk.MustQuery("select collation(CONCAT(concat(mb4unicode), concat(mb4general))) from t;").Check(testkit.Rows("utf8mb4_bin")) + tk.MustQuery("select coercibility(cONcat(unicode, general)) from t;").Check(testkit.Rows("1")) + tk.MustQuery("select collation(concAt(unicode, general)) from t;").Check(testkit.Rows("utf8_bin")) + tk.MustQuery("select collation(concat(bin, mb4general)) from t;").Check(testkit.Rows("binary")) + tk.MustQuery("select coercibility(concat(bin, mb4general)) from t;").Check(testkit.Rows("2")) + tk.MustQuery("select collation(concat(mb4unicode, ascii_bin)) from t;").Check(testkit.Rows("utf8mb4_unicode_ci")) + tk.MustQuery("select coercibility(concat(mb4unicode, ascii_bin)) from t;").Check(testkit.Rows("2")) + tk.MustQuery("select 
collation(concat(mb4unicode, mb4unicode)) from t;").Check(testkit.Rows("utf8mb4_unicode_ci")) + tk.MustQuery("select coercibility(concat(mb4unicode, mb4unicode)) from t;").Check(testkit.Rows("2")) + tk.MustQuery("select collation(concat(bin, bin)) from t;").Check(testkit.Rows("binary")) + tk.MustQuery("select coercibility(concat(bin, bin)) from t;").Check(testkit.Rows("2")) + tk.MustQuery("select collation(concat(latin1_bin, ascii_bin)) from t;").Check(testkit.Rows("latin1_bin")) + tk.MustQuery("select coercibility(concat(latin1_bin, ascii_bin)) from t;").Check(testkit.Rows("2")) + tk.MustQuery("select collation(concat(mb4unicode, bin)) from t;").Check(testkit.Rows("binary")) + tk.MustQuery("select coercibility(concat(mb4unicode, bin)) from t;").Check(testkit.Rows("2")) + tk.MustQuery("select collation(mb4general collate utf8mb4_unicode_ci) from t;").Check(testkit.Rows("utf8mb4_unicode_ci")) + tk.MustQuery("select coercibility(mb4general collate utf8mb4_unicode_ci) from t;").Check(testkit.Rows("0")) + tk.MustQuery("select collation(concat(concat(mb4unicode, mb4general), concat(unicode, general))) from t;").Check(testkit.Rows("utf8mb4_bin")) + tk.MustQuery("select coercibility(concat(concat(mb4unicode, mb4general), concat(unicode, general))) from t;").Check(testkit.Rows("1")) + tk.MustGetErrMsg("select * from t where mb4unicode = mb4general;", "[expression:1267]Illegal mix of collations (utf8mb4_unicode_ci,IMPLICIT) and (utf8mb4_general_ci,IMPLICIT) for operation 'eq'") + tk.MustGetErrMsg("select * from t where unicode = general;", "[expression:1267]Illegal mix of collations (utf8_unicode_ci,IMPLICIT) and (utf8_general_ci,IMPLICIT) for operation 'eq'") + tk.MustGetErrMsg("select concat(mb4general) = concat(mb4unicode) from t;", "[expression:1267]Illegal mix of collations (utf8mb4_general_ci,IMPLICIT) and (utf8mb4_unicode_ci,IMPLICIT) for operation 'eq'") + tk.MustGetErrMsg("select * from t t1, t t2 where t1.mb4unicode = t2.mb4general;", "[expression:1267]Illegal mix 
of collations (utf8mb4_unicode_ci,IMPLICIT) and (utf8mb4_general_ci,IMPLICIT) for operation 'eq'") + tk.MustGetErrMsg("select field('s', mb4general, mb4unicode, mb4bin) from t;", "[expression:1271]Illegal mix of collations for operation 'field'") + tk.MustGetErrMsg("select concat(mb4unicode, mb4general) = mb4unicode from t;", "[expression:1267]Illegal mix of collations (utf8mb4_bin,NONE) and (utf8mb4_unicode_ci,IMPLICIT) for operation 'eq'") + + tk.MustExec("drop table t;") +} + func (s *testIntegrationSerialSuite) prepare4Join(c *C) *testkit.TestKit { tk := testkit.NewTestKit(c, s.store) tk.MustExec("USE test") diff --git a/util/collate/collate.go b/util/collate/collate.go index 32c73c3da617e..92e95dd614a1b 100644 --- a/util/collate/collate.go +++ b/util/collate/collate.go @@ -37,7 +37,11 @@ var ( // ErrUnsupportedCollation is returned when an unsupported collation is specified. ErrUnsupportedCollation = dbterror.ClassDDL.NewStdErr(mysql.ErrUnknownCollation, mysql.Message("Unsupported collation when new collation is enabled: '%-.64s'", nil)) // ErrIllegalMixCollation is returned when illegal mix of collations. - ErrIllegalMixCollation = dbterror.ClassExpression.NewStd(mysql.ErrCantAggregate2collations) + ErrIllegalMixCollation = dbterror.ClassExpression.NewStd(mysql.ErrCantAggregateNcollations) + // ErrIllegalMix2Collation is returned for an illegal mix of 2 collations. + ErrIllegalMix2Collation = dbterror.ClassExpression.NewStd(mysql.ErrCantAggregate2collations) + // ErrIllegalMix3Collation is returned for an illegal mix of 3 collations. + ErrIllegalMix3Collation = dbterror.ClassExpression.NewStd(mysql.ErrCantAggregate3collations) ) // DefaultLen is set for datum if the string datum don't know its length.