Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

planner, type: fix AggFieldType error when encouter unsigned and sign type (#21062) #21236

Merged
merged 4 commits into from
Nov 26, 2020
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 6 additions & 2 deletions expression/constant.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,17 +28,21 @@ import (

// NewOne stands for a number 1.
func NewOne() *Constant {
retT := types.NewFieldType(mysql.TypeTiny)
retT.Flag |= mysql.UnsignedFlag // shrink range to avoid integral promotion
return &Constant{
Value: types.NewDatum(1),
RetType: types.NewFieldType(mysql.TypeTiny),
RetType: retT,
}
}

// NewZero stands for a number 0.
func NewZero() *Constant {
retT := types.NewFieldType(mysql.TypeTiny)
retT.Flag |= mysql.UnsignedFlag // shrink range to avoid integral promotion
return &Constant{
Value: types.NewDatum(0),
RetType: types.NewFieldType(mysql.TypeTiny),
RetType: retT,
}
}

Expand Down
36 changes: 36 additions & 0 deletions planner/core/expression_rewriter_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -292,3 +292,39 @@ func (s *testExpressionRewriterSuite) TestIssue20007(c *C) {
testkit.Rows("2 epic wiles 2020-01-02 23:29:51", "3 silly burnell 2020-02-25 07:43:07"))
}
}

func (s *testExpressionRewriterSuite) TestIssue9869(c *C) {
defer testleak.AfterTest(c)()
store, dom, err := newStoreWithBootstrap()
c.Assert(err, IsNil)
tk := testkit.NewTestKit(c, store)
defer func() {
dom.Close()
store.Close()
}()

tk.MustExec("use test;")
tk.MustExec("drop table if exists t1;")
tk.MustExec("create table t1(a int, b bigint unsigned);")
tk.MustExec("insert into t1 (a, b) values (1,4572794622775114594), (2,18196094287899841997),(3,11120436154190595086);")
tk.MustQuery("select (case t1.a when 0 then 0 else t1.b end), cast(t1.b as signed) from t1;").Check(
testkit.Rows("4572794622775114594 4572794622775114594", "18196094287899841997 -250649785809709619", "11120436154190595086 -7326307919518956530"))
}

func (s *testExpressionRewriterSuite) TestIssue17652(c *C) {
defer testleak.AfterTest(c)()
store, dom, err := newStoreWithBootstrap()
c.Assert(err, IsNil)
tk := testkit.NewTestKit(c, store)
defer func() {
dom.Close()
store.Close()
}()

tk.MustExec("use test;")
tk.MustExec("drop table if exists t;")
tk.MustExec("create table t(x bigint unsigned);")
tk.MustExec("insert into t values( 9999999703771440633);")
tk.MustQuery("select ifnull(max(x), 0) from t").Check(
testkit.Rows("9999999703771440633"))
}
9 changes: 9 additions & 0 deletions types/etc.go
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,15 @@ func IsTypeTime(tp byte) bool {
return tp == mysql.TypeDatetime || tp == mysql.TypeDate || tp == mysql.TypeTimestamp
}

// IsTypeInteger returns a boolean indicating whether the tp is integer type.
func IsTypeInteger(tp byte) bool {
switch tp {
case mysql.TypeTiny, mysql.TypeShort, mysql.TypeInt24, mysql.TypeLong, mysql.TypeLonglong, mysql.TypeYear:
return true
}
return false
}

// IsTypeNumeric returns a boolean indicating whether the tp is numeric type.
func IsTypeNumeric(tp byte) bool {
switch tp {
Expand Down
29 changes: 26 additions & 3 deletions types/field_type.go
Original file line number Diff line number Diff line change
Expand Up @@ -65,15 +65,38 @@ func NewFieldTypeWithCollation(tp byte, collation string, length int) *FieldType
// Aggregation is performed by MergeFieldType function.
func AggFieldType(tps []*FieldType) *FieldType {
var currType FieldType
isMixedSign := false
for i, t := range tps {
if i == 0 && currType.Tp == mysql.TypeUnspecified {
currType = *t
continue
}
mtp := MergeFieldType(currType.Tp, t.Tp)
isMixedSign = isMixedSign || (mysql.HasUnsignedFlag(currType.Flag) != mysql.HasUnsignedFlag(t.Flag))
currType.Tp = mtp
currType.Flag = mergeTypeFlag(currType.Flag, t.Flag)
}
// integral promotion when tps contains signed and unsigned
if isMixedSign && IsTypeInteger(currType.Tp) {
bumpRange := false // indicate one of tps bump currType range
for _, t := range tps {
bumpRange = bumpRange || (mysql.HasUnsignedFlag(t.Flag) && (t.Tp == currType.Tp || t.Tp == mysql.TypeBit))
}
if bumpRange {
switch currType.Tp {
case mysql.TypeTiny:
currType.Tp = mysql.TypeShort
case mysql.TypeShort:
currType.Tp = mysql.TypeInt24
case mysql.TypeInt24:
currType.Tp = mysql.TypeLong
case mysql.TypeLong:
currType.Tp = mysql.TypeLonglong
case mysql.TypeLonglong:
currType.Tp = mysql.TypeNewDecimal
}
}
}

return &currType
}
Expand Down Expand Up @@ -310,10 +333,10 @@ func MergeFieldType(a byte, b byte) byte {
}

// mergeTypeFlag merges two MySQL type flag to a new one
// currently only NotNullFlag is checked
// todo more flag need to be checked, for example: UnsignedFlag
// currently only NotNullFlag and UnsignedFlag is checked
// todo more flag need to be checked
func mergeTypeFlag(a, b uint) uint {
return a & (b&mysql.NotNullFlag | ^mysql.NotNullFlag)
return a & (b&mysql.NotNullFlag | ^mysql.NotNullFlag) & (b&mysql.UnsignedFlag | ^mysql.UnsignedFlag)
}

func getFieldTypeIndex(tp byte) int {
Expand Down
38 changes: 38 additions & 0 deletions types/field_type_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -327,6 +327,44 @@ func (s *testFieldTypeSuite) TestAggFieldTypeForTypeFlag(c *C) {
c.Assert(aggTp.Flag, Equals, mysql.NotNullFlag)
}

func (s testFieldTypeSuite) TestAggFieldTypeForIntegralPromotion(c *C) {
fts := []*FieldType{
NewFieldType(mysql.TypeTiny),
NewFieldType(mysql.TypeShort),
NewFieldType(mysql.TypeInt24),
NewFieldType(mysql.TypeLong),
NewFieldType(mysql.TypeLonglong),
NewFieldType(mysql.TypeNewDecimal),
}

for i := 1; i < len(fts)-1; i++ {
tps := fts[i-1 : i+1]

tps[0].Flag = 0
tps[1].Flag = 0
aggTp := AggFieldType(tps)
c.Assert(aggTp.Tp, Equals, fts[i].Tp)
c.Assert(aggTp.Flag, Equals, uint(0))

tps[0].Flag = mysql.UnsignedFlag
aggTp = AggFieldType(tps)
c.Assert(aggTp.Tp, Equals, fts[i].Tp)
c.Assert(aggTp.Flag, Equals, uint(0))

tps[0].Flag = mysql.UnsignedFlag
tps[1].Flag = mysql.UnsignedFlag
aggTp = AggFieldType(tps)
c.Assert(aggTp.Tp, Equals, fts[i].Tp)
c.Assert(aggTp.Flag, Equals, mysql.UnsignedFlag)

tps[0].Flag = 0
tps[1].Flag = mysql.UnsignedFlag
aggTp = AggFieldType(tps)
c.Assert(aggTp.Tp, Equals, fts[i+1].Tp)
c.Assert(aggTp.Flag, Equals, uint(0))
}
}

func (s *testFieldTypeSuite) TestAggregateEvalType(c *C) {
defer testleak.AfterTest(c)()
fts := []*FieldType{
Expand Down