Skip to content

Commit

Permalink
expression: consider collations when comparing strings (#14913)
Browse files Browse the repository at this point in the history
  • Loading branch information
qw4990 authored Feb 25, 2020
1 parent 51a1323 commit f2fa5c5
Show file tree
Hide file tree
Showing 18 changed files with 148 additions and 63 deletions.
6 changes: 4 additions & 2 deletions executor/aggfuncs/func_max_min.go
Original file line number Diff line number Diff line change
Expand Up @@ -421,6 +421,7 @@ func (e *maxMin4String) AppendFinalResult2Chunk(sctx sessionctx.Context, pr Part

func (e *maxMin4String) UpdatePartialResult(sctx sessionctx.Context, rowsInGroup []chunk.Row, pr PartialResult) error {
p := (*partialResult4MaxMinString)(pr)
tp := e.args[0].GetType()
for _, row := range rowsInGroup {
input, isNull, err := e.args[0].EvalString(sctx, row)
if err != nil {
Expand All @@ -438,7 +439,7 @@ func (e *maxMin4String) UpdatePartialResult(sctx sessionctx.Context, rowsInGroup
p.isNull = false
continue
}
cmp := types.CompareString(input, p.val)
cmp := types.CompareString(input, p.val, tp.Collate, tp.Flen)
if e.isMax && cmp == 1 || !e.isMax && cmp == -1 {
p.val = stringutil.Copy(input)
}
Expand All @@ -455,7 +456,8 @@ func (e *maxMin4String) MergePartialResult(sctx sessionctx.Context, src, dst Par
*p2 = *p1
return nil
}
cmp := types.CompareString(p1.val, p2.val)
tp := e.args[0].GetType()
cmp := types.CompareString(p1.val, p2.val, tp.Collate, tp.Flen)
if e.isMax && cmp > 0 || !e.isMax && cmp < 0 {
p2.val, p2.isNull = p1.val, false
}
Expand Down
8 changes: 4 additions & 4 deletions executor/benchmark_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1206,8 +1206,8 @@ func prepare4IndexMergeJoin(tc *indexJoinTestCase, outerDS *mockDataSource, inne
compareFuncs := make([]expression.CompareFunc, 0, len(outerJoinKeys))
outerCompareFuncs := make([]expression.CompareFunc, 0, len(outerJoinKeys))
for i := range outerJoinKeys {
compareFuncs = append(compareFuncs, expression.GetCmpFunction(outerJoinKeys[i], innerJoinKeys[i]))
outerCompareFuncs = append(outerCompareFuncs, expression.GetCmpFunction(outerJoinKeys[i], outerJoinKeys[i]))
compareFuncs = append(compareFuncs, expression.GetCmpFunction(nil, outerJoinKeys[i], innerJoinKeys[i]))
outerCompareFuncs = append(outerCompareFuncs, expression.GetCmpFunction(nil, outerJoinKeys[i], outerJoinKeys[i]))
}
e := &IndexLookUpMergeJoin{
baseExecutor: newBaseExecutor(tc.ctx, joinSchema, stringutil.StringerStr("IndexMergeJoin"), outerDS),
Expand Down Expand Up @@ -1343,8 +1343,8 @@ func prepare4MergeJoin(tc *mergeJoinTestCase, leftExec, rightExec *mockDataSourc
compareFuncs := make([]expression.CompareFunc, 0, len(outerJoinKeys))
outerCompareFuncs := make([]expression.CompareFunc, 0, len(outerJoinKeys))
for i := range outerJoinKeys {
compareFuncs = append(compareFuncs, expression.GetCmpFunction(outerJoinKeys[i], innerJoinKeys[i]))
outerCompareFuncs = append(outerCompareFuncs, expression.GetCmpFunction(outerJoinKeys[i], outerJoinKeys[i]))
compareFuncs = append(compareFuncs, expression.GetCmpFunction(nil, outerJoinKeys[i], innerJoinKeys[i]))
outerCompareFuncs = append(outerCompareFuncs, expression.GetCmpFunction(nil, outerJoinKeys[i], outerJoinKeys[i]))
}

defaultValues := make([]types.Datum, len(innerCols))
Expand Down
2 changes: 1 addition & 1 deletion executor/executor_required_rows_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -813,7 +813,7 @@ func buildMergeJoinExec(ctx sessionctx.Context, joinType plannercore.JoinType, i

j.CompareFuncs = make([]expression.CompareFunc, 0, len(j.LeftJoinKeys))
for i := range j.LeftJoinKeys {
j.CompareFuncs = append(j.CompareFuncs, expression.GetCmpFunction(j.LeftJoinKeys[i], j.RightJoinKeys[i]))
j.CompareFuncs = append(j.CompareFuncs, expression.GetCmpFunction(nil, j.LeftJoinKeys[i], j.RightJoinKeys[i]))
}

b := newExecutorBuilder(ctx, nil)
Expand Down
36 changes: 22 additions & 14 deletions expression/builtin_compare.go
Original file line number Diff line number Diff line change
Expand Up @@ -567,7 +567,7 @@ func (b *builtinGreatestStringSig) evalString(row chunk.Row) (max string, isNull
if isNull || err != nil {
return max, isNull, err
}
if types.CompareString(v, max) > 0 {
if types.CompareString(v, max, b.tp.Collate, b.tp.Flen) > 0 {
max = v
}
}
Expand Down Expand Up @@ -766,7 +766,7 @@ func (b *builtinLeastStringSig) evalString(row chunk.Row) (min string, isNull bo
if isNull || err != nil {
return min, isNull, err
}
if types.CompareString(v, min) < 0 {
if types.CompareString(v, min, b.tp.Collate, b.tp.Flen) < 0 {
min = v
}
}
Expand Down Expand Up @@ -1044,7 +1044,7 @@ func GetAccurateCmpType(lhs, rhs Expression) types.EvalType {
}

// GetCmpFunction get the compare function according to two arguments.
func GetCmpFunction(lhs, rhs Expression) CompareFunc {
func GetCmpFunction(ctx sessionctx.Context, lhs, rhs Expression) CompareFunc {
switch GetAccurateCmpType(lhs, rhs) {
case types.ETInt:
return CompareInt
Expand All @@ -1053,7 +1053,8 @@ func GetCmpFunction(lhs, rhs Expression) CompareFunc {
case types.ETDecimal:
return CompareDecimal
case types.ETString:
return CompareString
_, dstCollation, dstFlen := DeriveCollationFromExprs(ctx, lhs, rhs)
return genCompareString(dstCollation, dstFlen)
case types.ETDuration:
return CompareDuration
case types.ETDatetime, types.ETTimestamp:
Expand Down Expand Up @@ -1519,7 +1520,7 @@ func (b *builtinLTStringSig) Clone() builtinFunc {
}

func (b *builtinLTStringSig) evalInt(row chunk.Row) (val int64, isNull bool, err error) {
return resOfLT(CompareString(b.ctx, b.args[0], b.args[1], row, row))
return resOfLT(CompareStringWithCollationInfo(b.ctx, b.args[0], b.args[1], row, row, b.tp.Collate, b.tp.Flen))
}

type builtinLTDurationSig struct {
Expand Down Expand Up @@ -1617,7 +1618,7 @@ func (b *builtinLEStringSig) Clone() builtinFunc {
}

func (b *builtinLEStringSig) evalInt(row chunk.Row) (val int64, isNull bool, err error) {
return resOfLE(CompareString(b.ctx, b.args[0], b.args[1], row, row))
return resOfLE(CompareStringWithCollationInfo(b.ctx, b.args[0], b.args[1], row, row, b.tp.Collate, b.tp.Flen))
}

type builtinLEDurationSig struct {
Expand Down Expand Up @@ -1715,7 +1716,7 @@ func (b *builtinGTStringSig) Clone() builtinFunc {
}

func (b *builtinGTStringSig) evalInt(row chunk.Row) (val int64, isNull bool, err error) {
return resOfGT(CompareString(b.ctx, b.args[0], b.args[1], row, row))
return resOfGT(CompareStringWithCollationInfo(b.ctx, b.args[0], b.args[1], row, row, b.tp.Collate, b.tp.Flen))
}

type builtinGTDurationSig struct {
Expand Down Expand Up @@ -1813,7 +1814,7 @@ func (b *builtinGEStringSig) Clone() builtinFunc {
}

func (b *builtinGEStringSig) evalInt(row chunk.Row) (val int64, isNull bool, err error) {
return resOfGE(CompareString(b.ctx, b.args[0], b.args[1], row, row))
return resOfGE(CompareStringWithCollationInfo(b.ctx, b.args[0], b.args[1], row, row, b.tp.Collate, b.tp.Flen))
}

type builtinGEDurationSig struct {
Expand Down Expand Up @@ -1911,7 +1912,7 @@ func (b *builtinEQStringSig) Clone() builtinFunc {
}

func (b *builtinEQStringSig) evalInt(row chunk.Row) (val int64, isNull bool, err error) {
return resOfEQ(CompareString(b.ctx, b.args[0], b.args[1], row, row))
return resOfEQ(CompareStringWithCollationInfo(b.ctx, b.args[0], b.args[1], row, row, b.tp.Collate, b.tp.Flen))
}

type builtinEQDurationSig struct {
Expand Down Expand Up @@ -2009,7 +2010,7 @@ func (b *builtinNEStringSig) Clone() builtinFunc {
}

func (b *builtinNEStringSig) evalInt(row chunk.Row) (val int64, isNull bool, err error) {
return resOfNE(CompareString(b.ctx, b.args[0], b.args[1], row, row))
return resOfNE(CompareStringWithCollationInfo(b.ctx, b.args[0], b.args[1], row, row, b.tp.Collate, b.tp.Flen))
}

type builtinNEDurationSig struct {
Expand Down Expand Up @@ -2189,7 +2190,7 @@ func (b *builtinNullEQStringSig) evalInt(row chunk.Row) (val int64, isNull bool,
res = 1
case isNull0 != isNull1:
break
case types.CompareString(arg0, arg1) == 0:
case types.CompareString(arg0, arg1, b.tp.Collate, b.tp.Flen) == 0:
res = 1
}
return res, false, nil
Expand Down Expand Up @@ -2420,8 +2421,15 @@ func CompareInt(sctx sessionctx.Context, lhsArg, rhsArg Expression, lhsRow, rhsR
return int64(res), false, nil
}

// CompareString compares two strings.
func CompareString(sctx sessionctx.Context, lhsArg, rhsArg Expression, lhsRow, rhsRow chunk.Row) (int64, bool, error) {
func genCompareString(collation string, flen int) func(sctx sessionctx.Context, lhsArg, rhsArg Expression, lhsRow, rhsRow chunk.Row) (int64, bool, error) {
return func(sctx sessionctx.Context, lhsArg, rhsArg Expression, lhsRow, rhsRow chunk.Row) (int64, bool, error) {
return CompareStringWithCollationInfo(sctx, lhsArg, rhsArg, lhsRow, rhsRow, collation, flen)
}
}

// CompareStringWithCollationInfo compares two strings with the specified collation information.
func CompareStringWithCollationInfo(sctx sessionctx.Context, lhsArg, rhsArg Expression, lhsRow, rhsRow chunk.Row,
collation string, flen int) (int64, bool, error) {
arg0, isNull0, err := lhsArg.EvalString(sctx, lhsRow)
if err != nil {
return 0, true, err
Expand All @@ -2435,7 +2443,7 @@ func CompareString(sctx sessionctx.Context, lhsArg, rhsArg Expression, lhsRow, r
if isNull0 || isNull1 {
return compareNull(isNull0, isNull1), true, nil
}
return int64(types.CompareString(arg0, arg1)), false, nil
return int64(types.CompareString(arg0, arg1, collation, flen)), false, nil
}

// CompareReal compares two float-point values.
Expand Down
4 changes: 2 additions & 2 deletions expression/builtin_compare_vec.go
Original file line number Diff line number Diff line change
Expand Up @@ -265,7 +265,7 @@ func (b *builtinLeastStringSig) vecEvalString(input *chunk.Chunk, result *chunk.
}
srcStr := src.GetString(i)
argStr := arg.GetString(i)
if types.CompareString(srcStr, argStr) < 0 {
if types.CompareString(srcStr, argStr, b.tp.Collate, b.tp.Flen) < 0 {
dst.AppendString(srcStr)
} else {
dst.AppendString(argStr)
Expand Down Expand Up @@ -802,7 +802,7 @@ func (b *builtinGreatestStringSig) vecEvalString(input *chunk.Chunk, result *chu
}
srcStr := src.GetString(i)
argStr := arg.GetString(i)
if types.CompareString(srcStr, argStr) > 0 {
if types.CompareString(srcStr, argStr, b.tp.Collate, b.tp.Flen) > 0 {
dst.AppendString(srcStr)
} else {
dst.AppendString(argStr)
Expand Down
14 changes: 7 additions & 7 deletions expression/builtin_compare_vec_generated.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion expression/builtin_other_vec_generated.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion expression/builtin_string.go
Original file line number Diff line number Diff line change
Expand Up @@ -925,7 +925,7 @@ func (b *builtinStrcmpSig) evalInt(row chunk.Row) (int64, bool, error) {
if isNull || err != nil {
return 0, isNull, err
}
res := types.CompareString(left, right)
res := types.CompareString(left, right, b.tp.Collate, b.tp.Flen)
return int64(res), false, nil
}

Expand Down
2 changes: 1 addition & 1 deletion expression/builtin_string_vec.go
Original file line number Diff line number Diff line change
Expand Up @@ -1183,7 +1183,7 @@ func (b *builtinStrcmpSig) vecEvalInt(input *chunk.Chunk, result *chunk.Column)
if result.IsNull(i) {
continue
}
i64s[i] = int64(types.CompareString(leftBuf.GetString(i), rightBuf.GetString(i)))
i64s[i] = int64(types.CompareString(leftBuf.GetString(i), rightBuf.GetString(i), b.tp.Collate, b.tp.Flen))
}
return nil
}
Expand Down
6 changes: 5 additions & 1 deletion expression/collation.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ import (
"strings"

"github.com/pingcap/parser/ast"
"github.com/pingcap/parser/charset"
"github.com/pingcap/tidb/sessionctx"
"github.com/pingcap/tidb/types"
)
Expand Down Expand Up @@ -115,7 +116,10 @@ func deriveCoercibilityForColumn(c *Column) Coercibility {
// DeriveCollationFromExprs derives collation information from these expressions.
func DeriveCollationFromExprs(ctx sessionctx.Context, exprs ...Expression) (dstCharset, dstCollation string, dstFlen int) {
curCoer := CoercibilityCoercible
dstCharset, dstCollation = ctx.GetSessionVars().GetCharsetInfo()
dstCharset, dstCollation = charset.GetDefaultCharsetAndCollate()
if ctx != nil && ctx.GetSessionVars() != nil {
dstCharset, dstCollation = ctx.GetSessionVars().GetCharsetInfo()
}
dstFlen = types.UnspecifiedLength
// see https://dev.mysql.com/doc/refman/8.0/en/charset-collation-coercibility.html
for _, e := range exprs {
Expand Down
67 changes: 67 additions & 0 deletions expression/collation_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
// Copyright 2020 PingCAP, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// See the License for the specific language governing permissions and
// limitations under the License.

package expression

import (
. "github.com/pingcap/check"
"github.com/pingcap/parser/mysql"
"github.com/pingcap/tidb/types"
"github.com/pingcap/tidb/util/chunk"
"github.com/pingcap/tidb/util/collate"
"github.com/pingcap/tidb/util/mock"
)

var _ = SerialSuites(&testCollationSuites{})

type testCollationSuites struct{}

func (s *testCollationSuites) TestCompareString(c *C) {
collate.SetNewCollationEnabledForTest(true)
defer collate.SetNewCollationEnabledForTest(false)

c.Assert(types.CompareString("a", "A", "utf8_general_ci", 1), Equals, 0)
c.Assert(types.CompareString("À", "A", "utf8_general_ci", 1), Equals, 0)
c.Assert(types.CompareString("😜", "😃", "utf8_general_ci", 1), Equals, 0)
c.Assert(types.CompareString("a ", "a ", "utf8_general_ci", 3), Equals, 0)
c.Assert(types.CompareString("a", "A", "binary", 1), Not(Equals), 0)
c.Assert(types.CompareString("À", "A", "binary", 1), Not(Equals), 0)
c.Assert(types.CompareString("😜", "😃", "binary", 1), Not(Equals), 0)
c.Assert(types.CompareString("a ", "a ", "binary", 3), Not(Equals), 0)

ctx := mock.NewContext()
ft := types.NewFieldType(mysql.TypeVarString)
col1 := &Column{
RetType: ft,
Index: 0,
}
col2 := &Column{
RetType: ft,
Index: 1,
}
chk := chunk.NewChunkWithCapacity([]*types.FieldType{ft, ft}, 4)
chk.Column(0).AppendString("a")
chk.Column(1).AppendString("A")
chk.Column(0).AppendString("À")
chk.Column(1).AppendString("A")
chk.Column(0).AppendString("😜")
chk.Column(1).AppendString("😃")
chk.Column(0).AppendString("a ")
chk.Column(1).AppendString("a ")
for i := 0; i < 4; i++ {
v, isNull, err := CompareStringWithCollationInfo(ctx, col1, col2, chk.GetRow(0), chk.GetRow(0), "utf8_general_ci", 3)
c.Assert(err, IsNil)
c.Assert(isNull, IsFalse)
c.Assert(v, Equals, int64(0))
}
}
4 changes: 2 additions & 2 deletions expression/generator/compare_vec.go
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ func (b *builtin{{ .compare.CompareName }}{{ .type.TypeName }}Sig) vecEvalInt(in
{{- else if eq .type.ETName "Real" }}
val := types.CompareFloat64(arg0[i], arg1[i])
{{- else if eq .type.ETName "String" }}
val := types.CompareString(buf0.GetString(i), buf1.GetString(i))
val := types.CompareString(buf0.GetString(i), buf1.GetString(i), b.tp.Collate, b.tp.Flen)
{{- else if eq .type.ETName "Duration" }}
val := types.CompareDuration(arg0[i], arg1[i])
{{- else if eq .type.ETName "Datetime" }}
Expand Down Expand Up @@ -151,7 +151,7 @@ func (b *builtin{{ .compare.CompareName }}{{ .type.TypeName }}Sig) vecEvalInt(in
{{- else if eq .type.ETName "Real" }}
case types.CompareFloat64(arg0[i], arg1[i]) == 0:
{{- else if eq .type.ETName "String" }}
case types.CompareString(buf0.GetString(i), buf1.GetString(i)) == 0:
case types.CompareString(buf0.GetString(i), buf1.GetString(i), b.tp.Collate, b.tp.Flen) == 0:
{{- else if eq .type.ETName "Duration" }}
case types.CompareDuration(arg0[i], arg1[i]) == 0:
{{- else if eq .type.ETName "Datetime" }}
Expand Down
Loading

0 comments on commit f2fa5c5

Please sign in to comment.