Skip to content

Commit

Permalink
expression: make field and findInSet support collation (#15100)
Browse files Browse the repository at this point in the history
  • Loading branch information
wjhuang2016 authored Mar 4, 2020
1 parent efa811a commit 1771fff
Show file tree
Hide file tree
Showing 15 changed files with 99 additions and 87 deletions.
4 changes: 2 additions & 2 deletions executor/aggfuncs/func_max_min.go
Original file line number Diff line number Diff line change
Expand Up @@ -439,7 +439,7 @@ func (e *maxMin4String) UpdatePartialResult(sctx sessionctx.Context, rowsInGroup
p.isNull = false
continue
}
cmp := types.CompareString(input, p.val, tp.Collate, tp.Flen)
cmp := types.CompareString(input, p.val, tp.Collate)
if e.isMax && cmp == 1 || !e.isMax && cmp == -1 {
p.val = stringutil.Copy(input)
}
Expand All @@ -457,7 +457,7 @@ func (e *maxMin4String) MergePartialResult(sctx sessionctx.Context, src, dst Par
return nil
}
tp := e.args[0].GetType()
cmp := types.CompareString(p1.val, p2.val, tp.Collate, tp.Flen)
cmp := types.CompareString(p1.val, p2.val, tp.Collate)
if e.isMax && cmp > 0 || !e.isMax && cmp < 0 {
p2.val, p2.isNull = p1.val, false
}
Expand Down
22 changes: 21 additions & 1 deletion expression/builtin.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ import (
"github.com/pingcap/tidb/types"
"github.com/pingcap/tidb/types/json"
"github.com/pingcap/tidb/util/chunk"
"github.com/pingcap/tidb/util/collate"
"github.com/pingcap/tipb/go-tipb"
)

Expand All @@ -49,6 +50,7 @@ type baseBuiltinFunc struct {
ctx sessionctx.Context
tp *types.FieldType
pbCode tipb.ScalarFuncSig
ctor collate.Collator

childrenVectorized bool
childrenReversed bool
Expand Down Expand Up @@ -76,11 +78,20 @@ func (b *baseBuiltinFunc) setPbCode(c tipb.ScalarFuncSig) {
b.pbCode = c
}

func (b *baseBuiltinFunc) setCollator(ctor collate.Collator) {
b.ctor = ctor
}

func (b *baseBuiltinFunc) collator() collate.Collator {
return b.ctor
}

func newBaseBuiltinFunc(ctx sessionctx.Context, args []Expression) baseBuiltinFunc {
if ctx == nil {
panic("ctx should not be nil")
}
return baseBuiltinFunc{
derivedCharset, derivedCollate, derivedFlen := DeriveCollationFromExprs(ctx, args...)
bf := baseBuiltinFunc{
bufAllocator: newLocalSliceBuffer(len(args)),
childrenVectorizedOnce: new(sync.Once),
childrenReversedOnce: new(sync.Once),
Expand All @@ -89,6 +100,9 @@ func newBaseBuiltinFunc(ctx sessionctx.Context, args []Expression) baseBuiltinFu
ctx: ctx,
tp: types.NewFieldType(mysql.TypeUnspecified),
}
bf.SetCharsetAndCollation(derivedCharset, derivedCollate, derivedFlen)
bf.setCollator(collate.GetCollator(derivedCollate))
return bf
}

// newBaseBuiltinFuncWithTp creates a built-in function signature with specified types of arguments and the return type of the function.
Expand Down Expand Up @@ -201,6 +215,7 @@ func newBaseBuiltinFuncWithTp(ctx sessionctx.Context, args []Expression, retType
tp: fieldType,
}
bf.SetCharsetAndCollation(derivedCharset, derivedCollate, derivedFlen)
bf.setCollator(collate.GetCollator(derivedCollate))
return bf
}

Expand Down Expand Up @@ -345,6 +360,7 @@ func (b *baseBuiltinFunc) cloneFrom(from *baseBuiltinFunc) {
b.bufAllocator = newLocalSliceBuffer(len(b.args))
b.childrenVectorizedOnce = new(sync.Once)
b.childrenReversedOnce = new(sync.Once)
b.ctor = from.ctor
}

func (b *baseBuiltinFunc) Clone() builtinFunc {
Expand Down Expand Up @@ -452,6 +468,10 @@ type builtinFunc interface {
setPbCode(tipb.ScalarFuncSig)
// PbCode returns PbCode of this signature.
PbCode() tipb.ScalarFuncSig
// setCollator sets collator for signature.
setCollator(ctor collate.Collator)
// collator returns collator of this signature.
collator() collate.Collator
// metadata returns the metadata of a function.
// metadata means some functions contain extra inner fields which will not
// contain in `tipb.Expr.children` but must be pushed down to coprocessor
Expand Down
34 changes: 12 additions & 22 deletions expression/builtin_compare.go
Original file line number Diff line number Diff line change
Expand Up @@ -561,14 +561,13 @@ func (b *builtinGreatestStringSig) evalString(row chunk.Row) (max string, isNull
if isNull || err != nil {
return max, isNull, err
}
_, collation, flen := b.CharsetAndCollation(b.ctx)
for i := 1; i < len(b.args); i++ {
var v string
v, isNull, err = b.args[i].EvalString(b.ctx, row)
if isNull || err != nil {
return max, isNull, err
}
if types.CompareString(v, max, collation, flen) > 0 {
if types.CompareString(v, max, b.collation) > 0 {
max = v
}
}
Expand Down Expand Up @@ -761,14 +760,13 @@ func (b *builtinLeastStringSig) evalString(row chunk.Row) (min string, isNull bo
if isNull || err != nil {
return min, isNull, err
}
_, collation, flen := b.CharsetAndCollation(b.ctx)
for i := 1; i < len(b.args); i++ {
var v string
v, isNull, err = b.args[i].EvalString(b.ctx, row)
if isNull || err != nil {
return min, isNull, err
}
if types.CompareString(v, min, collation, flen) < 0 {
if types.CompareString(v, min, b.collation) < 0 {
min = v
}
}
Expand Down Expand Up @@ -1522,8 +1520,7 @@ func (b *builtinLTStringSig) Clone() builtinFunc {
}

func (b *builtinLTStringSig) evalInt(row chunk.Row) (val int64, isNull bool, err error) {
_, collation, flen := b.CharsetAndCollation(b.ctx)
return resOfLT(CompareStringWithCollationInfo(b.ctx, b.args[0], b.args[1], row, row, collation, flen))
return resOfLT(CompareStringWithCollationInfo(b.ctx, b.args[0], b.args[1], row, row, b.collation))
}

type builtinLTDurationSig struct {
Expand Down Expand Up @@ -1621,8 +1618,7 @@ func (b *builtinLEStringSig) Clone() builtinFunc {
}

func (b *builtinLEStringSig) evalInt(row chunk.Row) (val int64, isNull bool, err error) {
_, collation, flen := b.CharsetAndCollation(b.ctx)
return resOfLE(CompareStringWithCollationInfo(b.ctx, b.args[0], b.args[1], row, row, collation, flen))
return resOfLE(CompareStringWithCollationInfo(b.ctx, b.args[0], b.args[1], row, row, b.collation))
}

type builtinLEDurationSig struct {
Expand Down Expand Up @@ -1720,8 +1716,7 @@ func (b *builtinGTStringSig) Clone() builtinFunc {
}

func (b *builtinGTStringSig) evalInt(row chunk.Row) (val int64, isNull bool, err error) {
_, collation, flen := b.CharsetAndCollation(b.ctx)
return resOfGT(CompareStringWithCollationInfo(b.ctx, b.args[0], b.args[1], row, row, collation, flen))
return resOfGT(CompareStringWithCollationInfo(b.ctx, b.args[0], b.args[1], row, row, b.collation))
}

type builtinGTDurationSig struct {
Expand Down Expand Up @@ -1819,8 +1814,7 @@ func (b *builtinGEStringSig) Clone() builtinFunc {
}

func (b *builtinGEStringSig) evalInt(row chunk.Row) (val int64, isNull bool, err error) {
_, collation, flen := b.CharsetAndCollation(b.ctx)
return resOfGE(CompareStringWithCollationInfo(b.ctx, b.args[0], b.args[1], row, row, collation, flen))
return resOfGE(CompareStringWithCollationInfo(b.ctx, b.args[0], b.args[1], row, row, b.collation))
}

type builtinGEDurationSig struct {
Expand Down Expand Up @@ -1918,8 +1912,7 @@ func (b *builtinEQStringSig) Clone() builtinFunc {
}

func (b *builtinEQStringSig) evalInt(row chunk.Row) (val int64, isNull bool, err error) {
_, collation, flen := b.CharsetAndCollation(b.ctx)
return resOfEQ(CompareStringWithCollationInfo(b.ctx, b.args[0], b.args[1], row, row, collation, flen))
return resOfEQ(CompareStringWithCollationInfo(b.ctx, b.args[0], b.args[1], row, row, b.collation))
}

type builtinEQDurationSig struct {
Expand Down Expand Up @@ -2017,8 +2010,7 @@ func (b *builtinNEStringSig) Clone() builtinFunc {
}

func (b *builtinNEStringSig) evalInt(row chunk.Row) (val int64, isNull bool, err error) {
_, collation, flen := b.CharsetAndCollation(b.ctx)
return resOfNE(CompareStringWithCollationInfo(b.ctx, b.args[0], b.args[1], row, row, collation, flen))
return resOfNE(CompareStringWithCollationInfo(b.ctx, b.args[0], b.args[1], row, row, b.collation))
}

type builtinNEDurationSig struct {
Expand Down Expand Up @@ -2193,13 +2185,12 @@ func (b *builtinNullEQStringSig) evalInt(row chunk.Row) (val int64, isNull bool,
return 0, true, err
}
var res int64
_, collation, flen := b.CharsetAndCollation(b.ctx)
switch {
case isNull0 && isNull1:
res = 1
case isNull0 != isNull1:
break
case types.CompareString(arg0, arg1, collation, flen) == 0:
case types.CompareString(arg0, arg1, b.collation) == 0:
res = 1
}
return res, false, nil
Expand Down Expand Up @@ -2432,13 +2423,12 @@ func CompareInt(sctx sessionctx.Context, lhsArg, rhsArg Expression, lhsRow, rhsR

func genCompareString(collation string, flen int) func(sctx sessionctx.Context, lhsArg, rhsArg Expression, lhsRow, rhsRow chunk.Row) (int64, bool, error) {
return func(sctx sessionctx.Context, lhsArg, rhsArg Expression, lhsRow, rhsRow chunk.Row) (int64, bool, error) {
return CompareStringWithCollationInfo(sctx, lhsArg, rhsArg, lhsRow, rhsRow, collation, flen)
return CompareStringWithCollationInfo(sctx, lhsArg, rhsArg, lhsRow, rhsRow, collation)
}
}

// CompareStringWithCollationInfo compares two strings with the specified collation information.
func CompareStringWithCollationInfo(sctx sessionctx.Context, lhsArg, rhsArg Expression, lhsRow, rhsRow chunk.Row,
collation string, flen int) (int64, bool, error) {
func CompareStringWithCollationInfo(sctx sessionctx.Context, lhsArg, rhsArg Expression, lhsRow, rhsRow chunk.Row, collation string) (int64, bool, error) {
arg0, isNull0, err := lhsArg.EvalString(sctx, lhsRow)
if err != nil {
return 0, true, err
Expand All @@ -2452,7 +2442,7 @@ func CompareStringWithCollationInfo(sctx sessionctx.Context, lhsArg, rhsArg Expr
if isNull0 || isNull1 {
return compareNull(isNull0, isNull1), true, nil
}
return int64(types.CompareString(arg0, arg1, collation, flen)), false, nil
return int64(types.CompareString(arg0, arg1, collation)), false, nil
}

// CompareReal compares two float-point values.
Expand Down
6 changes: 2 additions & 4 deletions expression/builtin_compare_vec.go
Original file line number Diff line number Diff line change
Expand Up @@ -254,7 +254,6 @@ func (b *builtinLeastStringSig) vecEvalString(input *chunk.Chunk, result *chunk.
src := result
arg := buf1
dst := buf2
_, collation, flen := b.CharsetAndCollation(b.ctx)
for j := 1; j < len(b.args); j++ {
if err := b.args[j].VecEvalString(b.ctx, input, arg); err != nil {
return err
Expand All @@ -266,7 +265,7 @@ func (b *builtinLeastStringSig) vecEvalString(input *chunk.Chunk, result *chunk.
}
srcStr := src.GetString(i)
argStr := arg.GetString(i)
if types.CompareString(srcStr, argStr, collation, flen) < 0 {
if types.CompareString(srcStr, argStr, b.collation) < 0 {
dst.AppendString(srcStr)
} else {
dst.AppendString(argStr)
Expand Down Expand Up @@ -792,7 +791,6 @@ func (b *builtinGreatestStringSig) vecEvalString(input *chunk.Chunk, result *chu
src := result
arg := buf1
dst := buf2
_, collation, flen := b.CharsetAndCollation(b.ctx)
for j := 1; j < len(b.args); j++ {
if err := b.args[j].VecEvalString(b.ctx, input, arg); err != nil {
return err
Expand All @@ -804,7 +802,7 @@ func (b *builtinGreatestStringSig) vecEvalString(input *chunk.Chunk, result *chu
}
srcStr := src.GetString(i)
argStr := arg.GetString(i)
if types.CompareString(srcStr, argStr, collation, flen) > 0 {
if types.CompareString(srcStr, argStr, b.collation) > 0 {
dst.AppendString(srcStr)
} else {
dst.AppendString(argStr)
Expand Down
21 changes: 7 additions & 14 deletions expression/builtin_compare_vec_generated.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 1 addition & 2 deletions expression/builtin_other_vec_generated.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading

0 comments on commit 1771fff

Please sign in to comment.