Skip to content

Commit

Permalink
expression: unicode_ci support when infer collation and charset infor…
Browse files Browse the repository at this point in the history
…mation (#19142) (#22581)

Signed-off-by: ti-srebot <ti-srebot@pingcap.com>
  • Loading branch information
ti-srebot authored Jan 28, 2021
1 parent 33fdb15 commit 187ef72
Show file tree
Hide file tree
Showing 6 changed files with 183 additions and 61 deletions.
10 changes: 10 additions & 0 deletions errors.toml
Original file line number Diff line number Diff line change
Expand Up @@ -496,6 +496,16 @@ error = '''
Illegal mix of collations (%s,%s) and (%s,%s) for operation '%s'
'''

["expression:1270"]
error = '''
Illegal mix of collations (%s,%s), (%s,%s), (%s,%s) for operation '%s'
'''

["expression:1271"]
error = '''
Illegal mix of collations for operation '%s'
'''

["expression:1365"]
error = '''
Division by 0
Expand Down
44 changes: 33 additions & 11 deletions expression/builtin.go
Original file line number Diff line number Diff line change
Expand Up @@ -109,23 +109,45 @@ func newBaseBuiltinFunc(ctx sessionctx.Context, funcName string, args []Expressi
return bf, nil
}

var (
	// allowDeriveNoneFunction is the set of function names for which two
	// incompatible collations sharing a charset may derive to CoercibilityNone
	// instead of being rejected as an illegal mix.
	allowDeriveNoneFunction = map[string]struct{}{
		ast.Concat:         {},
		ast.ConcatWS:       {},
		ast.Reverse:        {},
		ast.Replace:        {},
		ast.InsertFunc:     {},
		ast.Lower:          {},
		ast.Upper:          {},
		ast.Left:           {},
		ast.Right:          {},
		ast.Substr:         {},
		ast.SubstringIndex: {},
		ast.Trim:           {},
		ast.CurrentUser:    {},
		ast.Elt:            {},
		ast.MakeSet:        {},
		ast.Repeat:         {},
		ast.Rpad:           {},
		ast.Lpad:           {},
		ast.ExportSet:      {},
	}

	// coerString lists the display name of each coercibility level; it is
	// indexed by the Coercibility value when building error messages.
	coerString = []string{
		"EXPLICIT",
		"NONE",
		"IMPLICIT",
		"SYSCONST",
		"COERCIBLE",
		"NUMERIC",
		"IGNORABLE",
	}
)

// checkIllegalMixCollation reports an error when the collations of args cannot be mixed for funcName.
// It returns nil for fewer than two arguments; otherwise the verdict comes from inferCollation:
// an illegal combination is an error, and a derived CoercibilityNone is an error unless
// funcName is listed in allowDeriveNoneFunction.
// NOTE(review): this span is rendered from a diff, so the explicit-collation loop directly
// below appears to be the removed pre-change body shown next to its replacement — confirm
// against the repository file.
func checkIllegalMixCollation(funcName string, args []Expression) error {
firstExplicitCollation := ""
for _, arg := range args {
if arg.GetType().EvalType() != types.ETString {
continue
}
if arg.Coercibility() == CoercibilityExplicit {
if firstExplicitCollation == "" {
firstExplicitCollation = arg.GetType().Collate
} else if firstExplicitCollation != arg.GetType().Collate {
return collate.ErrIllegalMixCollation.GenWithStackByArgs(firstExplicitCollation, "EXPLICIT", arg.GetType().Collate, "EXPLICIT", funcName)
}
if len(args) < 2 {
return nil
}
_, _, coercibility, legal := inferCollation(args...)
if !legal {
return illegalMixCollationErr(funcName, args)
}
if coercibility == CoercibilityNone {
if _, ok := allowDeriveNoneFunction[funcName]; !ok {
return illegalMixCollationErr(funcName, args)
}
}
return nil
}

// illegalMixCollationErr builds the "Illegal mix of collations" error for funcName.
//
// For two or three arguments the error lists each argument's collation together
// with the name of its coercibility level (looked up in coerString); for any
// other argument count only funcName is reported.
func illegalMixCollationErr(funcName string, args []Expression) error {
	switch len(args) {
	case 2:
		return collate.ErrIllegalMix2Collation.GenWithStackByArgs(
			args[0].GetType().Collate, coerString[args[0].Coercibility()],
			args[1].GetType().Collate, coerString[args[1].Coercibility()],
			funcName)
	case 3:
		return collate.ErrIllegalMix3Collation.GenWithStackByArgs(
			args[0].GetType().Collate, coerString[args[0].Coercibility()],
			args[1].GetType().Collate, coerString[args[1].Coercibility()],
			// The third pair must describe args[2]; reading the collation from
			// args[0] here would repeat the first argument's collation in the message.
			args[2].GetType().Collate, coerString[args[2].Coercibility()],
			funcName)
	default:
		return collate.ErrIllegalMixCollation.GenWithStackByArgs(funcName)
	}
}

// newBaseBuiltinFuncWithTp creates a built-in function signature with specified types of arguments and the return type of the function.
// argTps indicates the types of the args, retType indicates the return type of the built-in function.
// Every built-in function needs determined argTps and retType when we create it.
Expand Down
104 changes: 55 additions & 49 deletions expression/collation.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,11 @@
package expression

import (
"strings"

"github.com/pingcap/parser/ast"
"github.com/pingcap/parser/charset"
"github.com/pingcap/tidb/sessionctx"
"github.com/pingcap/tidb/types"
"github.com/pingcap/tidb/util/logutil"
)

type collationInfo struct {
Expand Down Expand Up @@ -113,14 +112,17 @@ var (
}

// collationPriority is the priority used when inferring the result collation: collation a has a higher priority than collation b iff collationPriority[a] > collationPriority[b].
// Two different collations a and b are incompatible if collationPriority[a] == collationPriority[b].
collationPriority = map[string]int{
charset.CollationASCII: 0,
charset.CollationLatin1: 1,
"utf8_general_ci": 2,
charset.CollationUTF8: 3,
"utf8mb4_general_ci": 4,
charset.CollationUTF8MB4: 5,
charset.CollationBin: 6,
charset.CollationASCII: 1,
charset.CollationLatin1: 2,
"utf8_general_ci": 3,
"utf8_unicode_ci": 3,
charset.CollationUTF8: 4,
"utf8mb4_general_ci": 5,
"utf8mb4_unicode_ci": 5,
charset.CollationUTF8MB4: 6,
charset.CollationBin: 7,
}

// CollationStrictnessGroup group collation by strictness
Expand Down Expand Up @@ -151,18 +153,13 @@ func deriveCoercibilityForScarlarFunc(sf *ScalarFunction) Coercibility {
if sf.RetType.EvalType() != types.ETString {
return CoercibilityNumeric
}
coer := CoercibilityIgnorable
for _, arg := range sf.GetArgs() {
if arg.Coercibility() < coer {
coer = arg.Coercibility()
}
}

_, _, coer, _ := inferCollation(sf.GetArgs()...)

// it is weird if a ScalarFunction is CoercibilityNumeric but return string type
if coer == CoercibilityNumeric {
return CoercibilityCoercible
}

return coer
}

Expand All @@ -184,42 +181,51 @@ func deriveCoercibilityForColumn(c *Column) Coercibility {

// DeriveCollationFromExprs derives collation information from these expressions.
func DeriveCollationFromExprs(ctx sessionctx.Context, exprs ...Expression) (dstCharset, dstCollation string) {
curCoer := CoercibilityIgnorable
curCollationPriority := -1
dstCharset, dstCollation = charset.GetDefaultCharsetAndCollate()
if ctx != nil && ctx.GetSessionVars() != nil {
dstCharset, dstCollation = ctx.GetSessionVars().GetCharsetInfo()
if dstCharset == "" || dstCollation == "" {
dstCharset, dstCollation = charset.GetDefaultCharsetAndCollate()
}
}
hasStrArg := false
// see https://dev.mysql.com/doc/refman/8.0/en/charset-collation-coercibility.html
for _, e := range exprs {
if e.GetType().EvalType() != types.ETString {
continue
}
hasStrArg = true
dstCollation, dstCharset, _, _ = inferCollation(exprs...)
return
}

coer := e.Coercibility()
ft := e.GetType()
collationPriority, ok := collationPriority[strings.ToLower(ft.Collate)]
if !ok {
collationPriority = -1
}
if coer != curCoer {
if coer < curCoer {
curCoer, curCollationPriority, dstCharset, dstCollation = coer, collationPriority, ft.Charset, ft.Collate
// inferCollation infers the collation, charset and coercibility of exprs and checks whether mixing them is legal.
func inferCollation(exprs ...Expression) (dstCollation, dstCharset string, coercibility Coercibility, legal bool) {
firstExplicitCollation := ""
coercibility = CoercibilityIgnorable
dstCharset, dstCollation = charset.GetDefaultCharsetAndCollate()
for _, arg := range exprs {
if arg.Coercibility() == CoercibilityExplicit {
if firstExplicitCollation == "" {
firstExplicitCollation = arg.GetType().Collate
coercibility, dstCollation, dstCharset = CoercibilityExplicit, arg.GetType().Collate, arg.GetType().Charset
} else if firstExplicitCollation != arg.GetType().Collate {
return "", "", CoercibilityIgnorable, false
}
} else if arg.Coercibility() < coercibility {
coercibility, dstCollation, dstCharset = arg.Coercibility(), arg.GetType().Collate, arg.GetType().Charset
} else if arg.Coercibility() == coercibility && dstCollation != arg.GetType().Collate {
p1 := collationPriority[dstCollation]
p2 := collationPriority[arg.GetType().Collate]

// the same priority means these two collations are incompatible, so the coercibility derives to CoercibilityNone
if p1 == p2 {
coercibility, dstCollation, dstCharset = CoercibilityNone, getBinCollation(arg.GetType().Charset), arg.GetType().Charset
} else if p1 < p2 {
dstCollation, dstCharset = arg.GetType().Collate, arg.GetType().Charset
}
continue
}
if !ok || collationPriority <= curCollationPriority {
continue
}
curCollationPriority, dstCharset, dstCollation = collationPriority, ft.Charset, ft.Collate
}
if !hasStrArg {
dstCharset, dstCollation = charset.CharsetBin, charset.CollationBin

return dstCollation, dstCharset, coercibility, true
}

// getBinCollation returns the binary collation for charset cs.
// charset.CharsetUTF8 maps to charset.CollationUTF8 and charset.CharsetUTF8MB4 maps to
// charset.CollationUTF8MB4. Any other charset is logged as an error and
// charset.CollationUTF8MB4 is returned as a fallback.
func getBinCollation(cs string) string {
switch cs {
case charset.CharsetUTF8:
return charset.CollationUTF8
case charset.CharsetUTF8MB4:
return charset.CollationUTF8MB4
}
// NOTE(review): this bare return looks like a removed line carried over by the diff rendering — confirm.
return

logutil.BgLogger().Error("unexpected charset " + cs)
// A value must be returned here; this path is not expected to be reached.
return charset.CollationUTF8MB4
}
1 change: 1 addition & 0 deletions expression/collation_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@ func (s *testCollationSuites) TestCompareString(c *C) {

func (s *testCollationSuites) TestDeriveCollationFromExprs(c *C) {
tInt := types.NewFieldType(mysql.TypeLonglong)
tInt.Charset = charset.CharsetBin
ctx := mock.NewContext()

// no string column
Expand Down
79 changes: 79 additions & 0 deletions expression/integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6225,6 +6225,85 @@ func (s *testIntegrationSerialSuite) TestCollateConstantPropagation(c *C) {
tk.MustExec("insert into t values ('a', 'a');")
tk.MustQuery("select * from t t1, t t2 where t2.b = 'A' and lower(concat(t1.a , '' )) = t2.b;").Check(testkit.Rows("a a a a"))
}

// TestMixCollation checks how the result collation and coercibility are inferred when
// collations are mixed, and which mixes are rejected. The new collation framework is
// enabled for the duration of the test and disabled again on exit.
func (s *testIntegrationSerialSuite) TestMixCollation(c *C) {
tk := testkit.NewTestKit(c, s.store)
collate.SetNewCollationEnabledForTest(true)
defer collate.SetNewCollationEnabledForTest(false)

// Two different explicit collations cannot be compared.
tk.MustGetErrMsg(`select 'a' collate utf8mb4_bin = 'a' collate utf8mb4_general_ci;`, "[expression:1267]Illegal mix of collations (utf8mb4_bin,EXPLICIT) and (utf8mb4_general_ci,EXPLICIT) for operation 'eq'")

// Fixture: one row holding 's' in string columns of several charsets/collations, plus an int column.
tk.MustExec("use test;")
tk.MustExec("drop table if exists t;")
tk.MustExec(`create table t (
mb4general varchar(10) charset utf8mb4 collate utf8mb4_general_ci,
mb4unicode varchar(10) charset utf8mb4 collate utf8mb4_unicode_ci,
mb4bin varchar(10) charset utf8mb4 collate utf8mb4_bin,
general varchar(10) charset utf8 collate utf8_general_ci,
unicode varchar(10) charset utf8 collate utf8_unicode_ci,
utfbin varchar(10) charset utf8 collate utf8_bin,
bin varchar(10) charset binary collate binary,
latin1_bin varchar(10) charset latin1 collate latin1_bin,
ascii_bin varchar(10) charset ascii collate ascii_bin,
i int
);`)
tk.MustExec("insert into t values ('s', 's', 's', 's', 's', 's', 's', 's', 's', 1);")
tk.MustExec("set names utf8mb4 collate utf8mb4_general_ci;")

// Comparisons that are expected to succeed and return the fixture row
// (self-joins return the row twice, side by side).
tk.MustQuery("select * from t where mb4unicode = 's' collate utf8mb4_unicode_ci;").Check(testkit.Rows("s s s s s s s s s 1"))
tk.MustQuery(`select * from t t1, t t2 where t1.mb4unicode = t2.mb4general collate utf8mb4_general_ci;`).Check(testkit.Rows("s s s s s s s s s 1 s s s s s s s s s 1"))
tk.MustQuery(`select * from t t1, t t2 where t1.mb4general = t2.mb4unicode collate utf8mb4_general_ci;`).Check(testkit.Rows("s s s s s s s s s 1 s s s s s s s s s 1"))
tk.MustQuery(`select * from t t1, t t2 where t1.mb4general = t2.mb4unicode collate utf8mb4_unicode_ci;`).Check(testkit.Rows("s s s s s s s s s 1 s s s s s s s s s 1"))
tk.MustQuery(`select * from t t1, t t2 where t1.mb4unicode = t2.mb4general collate utf8mb4_unicode_ci;`).Check(testkit.Rows("s s s s s s s s s 1 s s s s s s s s s 1"))
tk.MustQuery(`select * from t where mb4general = mb4bin collate utf8mb4_general_ci;`).Check(testkit.Rows("s s s s s s s s s 1"))
tk.MustQuery(`select * from t where mb4unicode = mb4general collate utf8mb4_unicode_ci;`).Check(testkit.Rows("s s s s s s s s s 1"))
tk.MustQuery(`select * from t where mb4general = mb4unicode collate utf8mb4_unicode_ci;`).Check(testkit.Rows("s s s s s s s s s 1"))
tk.MustQuery(`select * from t where mb4unicode = 's' collate utf8mb4_unicode_ci;`).Check(testkit.Rows("s s s s s s s s s 1"))
tk.MustQuery("select * from t where mb4unicode = mb4bin;").Check(testkit.Rows("s s s s s s s s s 1"))
tk.MustQuery("select * from t where general = mb4unicode;").Check(testkit.Rows("s s s s s s s s s 1"))
tk.MustQuery("select * from t where unicode = mb4unicode;").Check(testkit.Rows("s s s s s s s s s 1"))
tk.MustQuery("select * from t where mb4unicode = mb4unicode;").Check(testkit.Rows("s s s s s s s s s 1"))

// collation() and coercibility() of derived expressions. In the expectations below,
// coercibility 0 is reported for an explicit COLLATE clause, 1 for a mix of
// incompatible collations (NONE) and 2 for plain column values (IMPLICIT).
tk.MustQuery("select collation(concat(mb4unicode, mb4general collate utf8mb4_unicode_ci)) from t;").Check(testkit.Rows("utf8mb4_unicode_ci"))
tk.MustQuery("select collation(concat(mb4general, mb4unicode, mb4bin)) from t;").Check(testkit.Rows("utf8mb4_bin"))
tk.MustQuery("select coercibility(concat(mb4general, mb4unicode, mb4bin)) from t;").Check(testkit.Rows("1"))
tk.MustQuery("select collation(concat(mb4unicode, mb4bin, concat(mb4general))) from t;").Check(testkit.Rows("utf8mb4_bin"))
tk.MustQuery("select coercibility(concat(mb4unicode, mb4bin)) from t;").Check(testkit.Rows("2"))
tk.MustQuery("select collation(concat(mb4unicode, mb4bin)) from t;").Check(testkit.Rows("utf8mb4_bin"))
tk.MustQuery("select coercibility(concat(mb4bin, concat(mb4general))) from t;").Check(testkit.Rows("2"))
tk.MustQuery("select collation(concaT(mb4bin, cOncAt(mb4general))) from t;").Check(testkit.Rows("utf8mb4_bin"))
tk.MustQuery("select coercibility(concat(mb4unicode, mb4bin, concat(mb4general))) from t;").Check(testkit.Rows("2"))
tk.MustQuery("select collation(concat(mb4unicode, mb4bin, concat(mb4general))) from t;").Check(testkit.Rows("utf8mb4_bin"))
tk.MustQuery("select coercibility(concat(mb4unicode, mb4general)) from t;").Check(testkit.Rows("1"))
tk.MustQuery("select collation(CONCAT(concat(mb4unicode), concat(mb4general))) from t;").Check(testkit.Rows("utf8mb4_bin"))
tk.MustQuery("select coercibility(cONcat(unicode, general)) from t;").Check(testkit.Rows("1"))
tk.MustQuery("select collation(concAt(unicode, general)) from t;").Check(testkit.Rows("utf8_bin"))
tk.MustQuery("select collation(concat(bin, mb4general)) from t;").Check(testkit.Rows("binary"))
tk.MustQuery("select coercibility(concat(bin, mb4general)) from t;").Check(testkit.Rows("2"))
tk.MustQuery("select collation(concat(mb4unicode, ascii_bin)) from t;").Check(testkit.Rows("utf8mb4_unicode_ci"))
tk.MustQuery("select coercibility(concat(mb4unicode, ascii_bin)) from t;").Check(testkit.Rows("2"))
tk.MustQuery("select collation(concat(mb4unicode, mb4unicode)) from t;").Check(testkit.Rows("utf8mb4_unicode_ci"))
tk.MustQuery("select coercibility(concat(mb4unicode, mb4unicode)) from t;").Check(testkit.Rows("2"))
tk.MustQuery("select collation(concat(bin, bin)) from t;").Check(testkit.Rows("binary"))
tk.MustQuery("select coercibility(concat(bin, bin)) from t;").Check(testkit.Rows("2"))
tk.MustQuery("select collation(concat(latin1_bin, ascii_bin)) from t;").Check(testkit.Rows("latin1_bin"))
tk.MustQuery("select coercibility(concat(latin1_bin, ascii_bin)) from t;").Check(testkit.Rows("2"))
tk.MustQuery("select collation(concat(mb4unicode, bin)) from t;").Check(testkit.Rows("binary"))
tk.MustQuery("select coercibility(concat(mb4unicode, bin)) from t;").Check(testkit.Rows("2"))
tk.MustQuery("select collation(mb4general collate utf8mb4_unicode_ci) from t;").Check(testkit.Rows("utf8mb4_unicode_ci"))
tk.MustQuery("select coercibility(mb4general collate utf8mb4_unicode_ci) from t;").Check(testkit.Rows("0"))
tk.MustQuery("select collation(concat(concat(mb4unicode, mb4general), concat(unicode, general))) from t;").Check(testkit.Rows("utf8mb4_bin"))
tk.MustQuery("select coercibility(concat(concat(mb4unicode, mb4general), concat(unicode, general))) from t;").Check(testkit.Rows("1"))
// Statements that are expected to fail with an "Illegal mix of collations" error.
tk.MustGetErrMsg("select * from t where mb4unicode = mb4general;", "[expression:1267]Illegal mix of collations (utf8mb4_unicode_ci,IMPLICIT) and (utf8mb4_general_ci,IMPLICIT) for operation 'eq'")
tk.MustGetErrMsg("select * from t where unicode = general;", "[expression:1267]Illegal mix of collations (utf8_unicode_ci,IMPLICIT) and (utf8_general_ci,IMPLICIT) for operation 'eq'")
tk.MustGetErrMsg("select concat(mb4general) = concat(mb4unicode) from t;", "[expression:1267]Illegal mix of collations (utf8mb4_general_ci,IMPLICIT) and (utf8mb4_unicode_ci,IMPLICIT) for operation 'eq'")
tk.MustGetErrMsg("select * from t t1, t t2 where t1.mb4unicode = t2.mb4general;", "[expression:1267]Illegal mix of collations (utf8mb4_unicode_ci,IMPLICIT) and (utf8mb4_general_ci,IMPLICIT) for operation 'eq'")
tk.MustGetErrMsg("select field('s', mb4general, mb4unicode, mb4bin) from t;", "[expression:1271]Illegal mix of collations for operation 'field'")
tk.MustGetErrMsg("select concat(mb4unicode, mb4general) = mb4unicode from t;", "[expression:1267]Illegal mix of collations (utf8mb4_bin,NONE) and (utf8mb4_unicode_ci,IMPLICIT) for operation 'eq'")

tk.MustExec("drop table t;")
}

func (s *testIntegrationSerialSuite) prepare4Join(c *C) *testkit.TestKit {
tk := testkit.NewTestKit(c, s.store)
tk.MustExec("USE test")
Expand Down
6 changes: 5 additions & 1 deletion util/collate/collate.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,11 @@ var (
// ErrUnsupportedCollation is returned when an unsupported collation is specified.
ErrUnsupportedCollation = dbterror.ClassDDL.NewStdErr(mysql.ErrUnknownCollation, mysql.Message("Unsupported collation when new collation is enabled: '%-.64s'", nil))
// ErrIllegalMixCollation is returned on an illegal mix of collations.
ErrIllegalMixCollation = dbterror.ClassExpression.NewStd(mysql.ErrCantAggregate2collations)
ErrIllegalMixCollation = dbterror.ClassExpression.NewStd(mysql.ErrCantAggregateNcollations)
// ErrIllegalMix2Collation is returned on an illegal mix of 2 collations.
ErrIllegalMix2Collation = dbterror.ClassExpression.NewStd(mysql.ErrCantAggregate2collations)
// ErrIllegalMix3Collation is returned on an illegal mix of 3 collations.
ErrIllegalMix3Collation = dbterror.ClassExpression.NewStd(mysql.ErrCantAggregate3collations)
)

// DefaultLen is set for datum if the string datum don't know its length.
Expand Down

0 comments on commit 187ef72

Please sign in to comment.