Skip to content

Commit

Permalink
expression: allow function coercibility derive to DERIVATIO ... (#19462)
Browse files Browse the repository at this point in the history
  • Loading branch information
xiongjiwei authored Sep 1, 2020
1 parent 2ae1cc1 commit 9c2d7c2
Show file tree
Hide file tree
Showing 7 changed files with 33 additions and 41 deletions.
22 changes: 6 additions & 16 deletions expression/builtin.go
Original file line number Diff line number Diff line change
Expand Up @@ -86,11 +86,11 @@ func (b *baseBuiltinFunc) collator() collate.Collator {
return b.ctor
}

func newBaseBuiltinFunc(ctx sessionctx.Context, funcName string, args []Expression) (baseBuiltinFunc, error) {
func newBaseBuiltinFunc(ctx sessionctx.Context, funcName string, args []Expression, retType types.EvalType) (baseBuiltinFunc, error) {
if ctx == nil {
return baseBuiltinFunc{}, errors.New("unexpected nil session ctx")
}
if err := checkIllegalMixCollation(funcName, args); err != nil {
if err := checkIllegalMixCollation(funcName, args, retType); err != nil {
return baseBuiltinFunc{}, err
}
derivedCharset, derivedCollate := DeriveCollationFromExprs(ctx, args...)
Expand All @@ -109,29 +109,19 @@ func newBaseBuiltinFunc(ctx sessionctx.Context, funcName string, args []Expressi
}

var (
// allowDeriveNoneFunction stores functions which allow two incompatible collations which have the same charset derive to CoercibilityNone
allowDeriveNoneFunction = map[string]struct{}{
ast.Concat: {}, ast.ConcatWS: {}, ast.Reverse: {}, ast.Replace: {}, ast.InsertFunc: {}, ast.Lower: {},
ast.Upper: {}, ast.Left: {}, ast.Right: {}, ast.Substr: {}, ast.SubstringIndex: {}, ast.Trim: {},
ast.CurrentUser: {}, ast.Elt: {}, ast.MakeSet: {}, ast.Repeat: {}, ast.Rpad: {}, ast.Lpad: {},
ast.ExportSet: {},
}

coerString = []string{"EXPLICIT", "NONE", "IMPLICIT", "SYSCONST", "COERCIBLE", "NUMERIC", "IGNORABLE"}
)

func checkIllegalMixCollation(funcName string, args []Expression) error {
func checkIllegalMixCollation(funcName string, args []Expression, evalType types.EvalType) error {
if len(args) < 2 {
return nil
}
_, _, coercibility, legal := inferCollation(args...)
if !legal {
return illegalMixCollationErr(funcName, args)
}
if coercibility == CoercibilityNone {
if _, ok := allowDeriveNoneFunction[funcName]; !ok {
return illegalMixCollationErr(funcName, args)
}
if coercibility == CoercibilityNone && evalType != types.ETString {
return illegalMixCollationErr(funcName, args)
}
return nil
}
Expand Down Expand Up @@ -179,7 +169,7 @@ func newBaseBuiltinFuncWithTp(ctx sessionctx.Context, funcName string, args []Ex
}
}

if err = checkIllegalMixCollation(funcName, args); err != nil {
if err = checkIllegalMixCollation(funcName, args, retType); err != nil {
return
}

Expand Down
14 changes: 7 additions & 7 deletions expression/builtin_cast.go
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@ func (c *castAsIntFunctionClass) getFunction(ctx sessionctx.Context, args []Expr
if err := c.verifyArgs(args); err != nil {
return nil, err
}
b, err := newBaseBuiltinFunc(ctx, c.funcName, args)
b, err := newBaseBuiltinFunc(ctx, c.funcName, args, c.tp.EvalType())
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -167,7 +167,7 @@ func (c *castAsRealFunctionClass) getFunction(ctx sessionctx.Context, args []Exp
if err := c.verifyArgs(args); err != nil {
return nil, err
}
b, err := newBaseBuiltinFunc(ctx, c.funcName, args)
b, err := newBaseBuiltinFunc(ctx, c.funcName, args, c.tp.EvalType())
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -222,7 +222,7 @@ func (c *castAsDecimalFunctionClass) getFunction(ctx sessionctx.Context, args []
if err := c.verifyArgs(args); err != nil {
return nil, err
}
b, err := newBaseBuiltinFunc(ctx, c.funcName, args)
b, err := newBaseBuiltinFunc(ctx, c.funcName, args, c.tp.EvalType())
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -277,7 +277,7 @@ func (c *castAsStringFunctionClass) getFunction(ctx sessionctx.Context, args []E
if err := c.verifyArgs(args); err != nil {
return nil, err
}
bf, err := newBaseBuiltinFunc(ctx, c.funcName, args)
bf, err := newBaseBuiltinFunc(ctx, c.funcName, args, c.tp.EvalType())
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -326,7 +326,7 @@ func (c *castAsTimeFunctionClass) getFunction(ctx sessionctx.Context, args []Exp
if err := c.verifyArgs(args); err != nil {
return nil, err
}
bf, err := newBaseBuiltinFunc(ctx, c.funcName, args)
bf, err := newBaseBuiltinFunc(ctx, c.funcName, args, c.tp.EvalType())
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -370,7 +370,7 @@ func (c *castAsDurationFunctionClass) getFunction(ctx sessionctx.Context, args [
if err := c.verifyArgs(args); err != nil {
return nil, err
}
bf, err := newBaseBuiltinFunc(ctx, c.funcName, args)
bf, err := newBaseBuiltinFunc(ctx, c.funcName, args, c.tp.EvalType())
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -414,7 +414,7 @@ func (c *castAsJSONFunctionClass) getFunction(ctx sessionctx.Context, args []Exp
if err := c.verifyArgs(args); err != nil {
return nil, err
}
bf, err := newBaseBuiltinFunc(ctx, c.funcName, args)
bf, err := newBaseBuiltinFunc(ctx, c.funcName, args, c.tp.EvalType())
if err != nil {
return nil, err
}
Expand Down
2 changes: 1 addition & 1 deletion expression/builtin_cast_bench_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ import (

func genCastIntAsInt() (*builtinCastIntAsIntSig, *chunk.Chunk, *chunk.Column) {
col := &Column{RetType: types.NewFieldType(mysql.TypeLonglong), Index: 0}
baseFunc, err := newBaseBuiltinFunc(mock.NewContext(), "", []Expression{col})
baseFunc, err := newBaseBuiltinFunc(mock.NewContext(), "", []Expression{col}, 0)
if err != nil {
panic(err)
}
Expand Down
28 changes: 14 additions & 14 deletions expression/builtin_cast_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -356,7 +356,7 @@ func (s *testEvaluatorSuite) TestCastFuncSig(c *C) {
}
for i, t := range castToDecCases {
args := []Expression{t.before}
b, err := newBaseBuiltinFunc(ctx, "", args)
b, err := newBaseBuiltinFunc(ctx, "", args, 0)
c.Assert(err, IsNil)
decFunc := newBaseBuiltinCastFunc(b, false)
decFunc.tp = types.NewFieldType(mysql.TypeNewDecimal)
Expand Down Expand Up @@ -442,7 +442,7 @@ func (s *testEvaluatorSuite) TestCastFuncSig(c *C) {
args := []Expression{t.before}
tp := types.NewFieldType(mysql.TypeNewDecimal)
tp.Flen, tp.Decimal = t.flen, t.decimal
b, err := newBaseBuiltinFunc(ctx, "", args)
b, err := newBaseBuiltinFunc(ctx, "", args, 0)
c.Assert(err, IsNil)
decFunc := newBaseBuiltinCastFunc(b, false)
decFunc.tp = tp
Expand Down Expand Up @@ -512,7 +512,7 @@ func (s *testEvaluatorSuite) TestCastFuncSig(c *C) {
}
for i, t := range castToIntCases {
args := []Expression{t.before}
b, err := newBaseBuiltinFunc(ctx, "", args)
b, err := newBaseBuiltinFunc(ctx, "", args, 0)
c.Assert(err, IsNil)
intFunc := newBaseBuiltinCastFunc(b, false)
switch i {
Expand Down Expand Up @@ -580,7 +580,7 @@ func (s *testEvaluatorSuite) TestCastFuncSig(c *C) {
}
for i, t := range castToRealCases {
args := []Expression{t.before}
b, err := newBaseBuiltinFunc(ctx, "", args)
b, err := newBaseBuiltinFunc(ctx, "", args, 0)
c.Assert(err, IsNil)
realFunc := newBaseBuiltinCastFunc(b, false)
switch i {
Expand Down Expand Up @@ -656,7 +656,7 @@ func (s *testEvaluatorSuite) TestCastFuncSig(c *C) {
tp := types.NewFieldType(mysql.TypeVarString)
tp.Charset = charset.CharsetBin
args := []Expression{t.before}
stringFunc, err := newBaseBuiltinFunc(ctx, "", args)
stringFunc, err := newBaseBuiltinFunc(ctx, "", args, 0)
c.Assert(err, IsNil)
stringFunc.tp = tp
switch i {
Expand Down Expand Up @@ -735,7 +735,7 @@ func (s *testEvaluatorSuite) TestCastFuncSig(c *C) {
args := []Expression{t.before}
tp := types.NewFieldType(mysql.TypeVarString)
tp.Flen, tp.Charset = t.flen, charset.CharsetBin
stringFunc, err := newBaseBuiltinFunc(ctx, "", args)
stringFunc, err := newBaseBuiltinFunc(ctx, "", args, 0)
c.Assert(err, IsNil)
stringFunc.tp = tp
switch i {
Expand Down Expand Up @@ -811,7 +811,7 @@ func (s *testEvaluatorSuite) TestCastFuncSig(c *C) {
args := []Expression{t.before}
tp := types.NewFieldType(mysql.TypeDatetime)
tp.Decimal = int(types.DefaultFsp)
timeFunc, err := newBaseBuiltinFunc(ctx, "", args)
timeFunc, err := newBaseBuiltinFunc(ctx, "", args, 0)
c.Assert(err, IsNil)
timeFunc.tp = tp
switch i {
Expand Down Expand Up @@ -896,7 +896,7 @@ func (s *testEvaluatorSuite) TestCastFuncSig(c *C) {
args := []Expression{t.before}
tp := types.NewFieldType(t.tp)
tp.Decimal = int(t.fsp)
timeFunc, err := newBaseBuiltinFunc(ctx, "", args)
timeFunc, err := newBaseBuiltinFunc(ctx, "", args, 0)
c.Assert(err, IsNil)
timeFunc.tp = tp
switch i {
Expand Down Expand Up @@ -978,7 +978,7 @@ func (s *testEvaluatorSuite) TestCastFuncSig(c *C) {
args := []Expression{t.before}
tp := types.NewFieldType(mysql.TypeDuration)
tp.Decimal = int(types.DefaultFsp)
durationFunc, err := newBaseBuiltinFunc(ctx, "", args)
durationFunc, err := newBaseBuiltinFunc(ctx, "", args, 0)
c.Assert(err, IsNil)
durationFunc.tp = tp
switch i {
Expand Down Expand Up @@ -1056,7 +1056,7 @@ func (s *testEvaluatorSuite) TestCastFuncSig(c *C) {
args := []Expression{t.before}
tp := types.NewFieldType(mysql.TypeDuration)
tp.Decimal = t.fsp
durationFunc, err := newBaseBuiltinFunc(ctx, "", args)
durationFunc, err := newBaseBuiltinFunc(ctx, "", args, 0)
c.Assert(err, IsNil)
durationFunc.tp = tp
switch i {
Expand Down Expand Up @@ -1089,7 +1089,7 @@ func (s *testEvaluatorSuite) TestCastFuncSig(c *C) {
// null case
args := []Expression{&Column{RetType: types.NewFieldType(mysql.TypeDouble), Index: 0}}
row := chunk.MutRowFromDatums([]types.Datum{types.NewDatum(nil)})
bf, err := newBaseBuiltinFunc(ctx, "", args)
bf, err := newBaseBuiltinFunc(ctx, "", args, 0)
c.Assert(err, IsNil)
bf.tp = types.NewFieldType(mysql.TypeVarString)
sig = &builtinCastRealAsStringSig{bf}
Expand All @@ -1100,7 +1100,7 @@ func (s *testEvaluatorSuite) TestCastFuncSig(c *C) {

// test hybridType case.
args = []Expression{&Constant{Value: types.NewDatum(types.Enum{Name: "a", Value: 0}), RetType: types.NewFieldType(mysql.TypeEnum)}}
b, err := newBaseBuiltinFunc(ctx, "", args)
b, err := newBaseBuiltinFunc(ctx, "", args, 0)
c.Assert(err, IsNil)
sig = &builtinCastStringAsIntSig{newBaseBuiltinCastFunc(b, false)}
iRes, isNull, err := sig.evalInt(chunk.Row{})
Expand All @@ -1118,7 +1118,7 @@ func (s *testEvaluatorSuite) TestCastJSONAsDecimalSig(c *C) {
}()

col := &Column{RetType: types.NewFieldType(mysql.TypeJSON), Index: 0}
b, err := newBaseBuiltinFunc(ctx, "", []Expression{col})
b, err := newBaseBuiltinFunc(ctx, "", []Expression{col}, 0)
c.Assert(err, IsNil)
decFunc := newBaseBuiltinCastFunc(b, false)
decFunc.tp = types.NewFieldType(mysql.TypeNewDecimal)
Expand Down Expand Up @@ -1420,7 +1420,7 @@ func (s *testEvaluatorSuite) TestCastIntAsIntVec(c *C) {
// for issue https://github.com/pingcap/tidb/issues/16825
func (s *testEvaluatorSuite) TestCastStringAsDecimalSigWithUnsignedFlagInUnion(c *C) {
col := &Column{RetType: types.NewFieldType(mysql.TypeString), Index: 0}
b, err := newBaseBuiltinFunc(mock.NewContext(), "", []Expression{col})
b, err := newBaseBuiltinFunc(mock.NewContext(), "", []Expression{col}, 0)
c.Assert(err, IsNil)
// set `inUnion` to `true`
decFunc := newBaseBuiltinCastFunc(b, true)
Expand Down
4 changes: 2 additions & 2 deletions expression/builtin_cast_vec_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -157,7 +157,7 @@ func (s *testEvaluatorSuite) TestVectorizedBuiltinCastFunc(c *C) {

func (s *testEvaluatorSuite) TestVectorizedCastRealAsTime(c *C) {
col := &Column{RetType: types.NewFieldType(mysql.TypeDouble), Index: 0}
baseFunc, err := newBaseBuiltinFunc(mock.NewContext(), "", []Expression{col})
baseFunc, err := newBaseBuiltinFunc(mock.NewContext(), "", []Expression{col}, 0)
if err != nil {
panic(err)
}
Expand Down Expand Up @@ -199,7 +199,7 @@ func genCastRealAsTime() *chunk.Chunk {
// for issue https://github.com/pingcap/tidb/issues/16825
func (s *testEvaluatorSuite) TestVectorizedCastStringAsDecimalWithUnsignedFlagInUnion(c *C) {
col := &Column{RetType: types.NewFieldType(mysql.TypeString), Index: 0}
baseFunc, err := newBaseBuiltinFunc(mock.NewContext(), "", []Expression{col})
baseFunc, err := newBaseBuiltinFunc(mock.NewContext(), "", []Expression{col}, 0)
if err != nil {
panic(err)
}
Expand Down
2 changes: 1 addition & 1 deletion expression/builtin_other.go
Original file line number Diff line number Diff line change
Expand Up @@ -798,7 +798,7 @@ func (c *valuesFunctionClass) getFunction(ctx sessionctx.Context, args []Express
if err = c.verifyArgs(args); err != nil {
return nil, err
}
bf, err := newBaseBuiltinFunc(ctx, c.funcName, args)
bf, err := newBaseBuiltinFunc(ctx, c.funcName, args, c.tp.EvalType())
if err != nil {
return nil, err
}
Expand Down
2 changes: 2 additions & 0 deletions expression/integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6063,6 +6063,8 @@ func (s *testIntegrationSerialSuite) TestMixCollation(c *C) {
tk.MustQuery("select coercibility(concat(mb4unicode, mb4bin, concat(mb4general))) from t;").Check(testkit.Rows("2"))
tk.MustQuery("select collation(concat(mb4unicode, mb4bin, concat(mb4general))) from t;").Check(testkit.Rows("utf8mb4_bin"))
tk.MustQuery("select coercibility(concat(mb4unicode, mb4general)) from t;").Check(testkit.Rows("1"))
tk.MustQuery("select collation(coalesce(mb4unicode, mb4general)) from t;").Check(testkit.Rows("utf8mb4_bin"))
tk.MustQuery("select coercibility(coalesce(mb4unicode, mb4general)) from t;").Check(testkit.Rows("1"))
tk.MustQuery("select collation(CONCAT(concat(mb4unicode), concat(mb4general))) from t;").Check(testkit.Rows("utf8mb4_bin"))
tk.MustQuery("select coercibility(cONcat(unicode, general)) from t;").Check(testkit.Rows("1"))
tk.MustQuery("select collation(concAt(unicode, general)) from t;").Check(testkit.Rows("utf8_bin"))
Expand Down

0 comments on commit 9c2d7c2

Please sign in to comment.