Skip to content

Commit

Permalink
executor: check for null values when comparing different groups during streamAgg (#15742)
Browse files Browse the repository at this point in the history
  • Loading branch information
Reminiscent authored Mar 27, 2020
1 parent 7223e7f commit 0d10f91
Show file tree
Hide file tree
Showing 3 changed files with 154 additions and 26 deletions.
96 changes: 73 additions & 23 deletions executor/aggregate.go
Original file line number Diff line number Diff line change
Expand Up @@ -1111,8 +1111,15 @@ func (e *vecGroupChecker) evalGroupItemsAndResolveGroups(item expression.Express
return err
}

previousIsNull := col.IsNull(0)
var firstRowDatum, lastRowDatum types.Datum
firstRowIsNull, lastRowIsNull := col.IsNull(0), col.IsNull(numRows-1)
if firstRowIsNull {
firstRowDatum.SetNull()
}
if lastRowIsNull {
lastRowDatum.SetNull()
}
previousIsNull := firstRowIsNull
switch eType {
case types.ETInt:
vals := col.Int64s()
Expand All @@ -1128,8 +1135,12 @@ func (e *vecGroupChecker) evalGroupItemsAndResolveGroups(item expression.Express
}
previousIsNull = isNull
}
firstRowDatum.SetInt64(vals[0])
lastRowDatum.SetInt64(vals[numRows-1])
if !firstRowIsNull {
firstRowDatum.SetInt64(vals[0])
}
if !lastRowIsNull {
lastRowDatum.SetInt64(vals[numRows-1])
}
case types.ETReal:
vals := col.Float64s()
for i := 1; i < numRows; i++ {
Expand All @@ -1144,8 +1155,12 @@ func (e *vecGroupChecker) evalGroupItemsAndResolveGroups(item expression.Express
}
previousIsNull = isNull
}
firstRowDatum.SetFloat64(vals[0])
lastRowDatum.SetFloat64(vals[numRows-1])
if !firstRowIsNull {
firstRowDatum.SetFloat64(vals[0])
}
if !lastRowIsNull {
lastRowDatum.SetFloat64(vals[numRows-1])
}
case types.ETDecimal:
vals := col.Decimals()
for i := 1; i < numRows; i++ {
Expand All @@ -1160,10 +1175,16 @@ func (e *vecGroupChecker) evalGroupItemsAndResolveGroups(item expression.Express
}
previousIsNull = isNull
}
// make a copy to avoid DATA RACE
firstDatum, lastDatum := vals[0], vals[numRows-1]
firstRowDatum.SetMysqlDecimal(&firstDatum)
lastRowDatum.SetMysqlDecimal(&lastDatum)
if !firstRowIsNull {
// make a copy to avoid DATA RACE
firstDatum := vals[0]
firstRowDatum.SetMysqlDecimal(&firstDatum)
}
if !lastRowIsNull {
// make a copy to avoid DATA RACE
lastDatum := vals[numRows-1]
lastRowDatum.SetMysqlDecimal(&lastDatum)
}
case types.ETDatetime, types.ETTimestamp:
vals := col.Times()
for i := 1; i < numRows; i++ {
Expand All @@ -1178,8 +1199,12 @@ func (e *vecGroupChecker) evalGroupItemsAndResolveGroups(item expression.Express
}
previousIsNull = isNull
}
firstRowDatum.SetMysqlTime(vals[0])
lastRowDatum.SetMysqlTime(vals[numRows-1])
if !firstRowIsNull {
firstRowDatum.SetMysqlTime(vals[0])
}
if !lastRowIsNull {
lastRowDatum.SetMysqlTime(vals[numRows-1])
}
case types.ETDuration:
vals := col.GoDurations()
for i := 1; i < numRows; i++ {
Expand All @@ -1194,24 +1219,44 @@ func (e *vecGroupChecker) evalGroupItemsAndResolveGroups(item expression.Express
}
previousIsNull = isNull
}
firstRowDatum.SetMysqlDuration(types.Duration{Duration: vals[0], Fsp: int8(item.GetType().Decimal)})
lastRowDatum.SetMysqlDuration(types.Duration{Duration: vals[numRows-1], Fsp: int8(item.GetType().Decimal)})
if !firstRowIsNull {
firstRowDatum.SetMysqlDuration(types.Duration{Duration: vals[0], Fsp: int8(item.GetType().Decimal)})
}
if !lastRowIsNull {
lastRowDatum.SetMysqlDuration(types.Duration{Duration: vals[numRows-1], Fsp: int8(item.GetType().Decimal)})
}
case types.ETJson:
previousKey := col.GetJSON(0)
var previousKey, key json.BinaryJSON
if !previousIsNull {
previousKey = col.GetJSON(0)
}
for i := 1; i < numRows; i++ {
key := col.GetJSON(i)
isNull := col.IsNull(i)
if !isNull {
key = col.GetJSON(i)
}
if e.sameGroup[i] {
if isNull != previousIsNull || json.CompareBinary(previousKey, key) != 0 {
if isNull == previousIsNull {
if !isNull && json.CompareBinary(previousKey, key) != 0 {
e.sameGroup[i] = false
}
} else {
e.sameGroup[i] = false
}
}
previousKey = key
if !isNull {
previousKey = key
}
previousIsNull = isNull
}
// make a copy to avoid DATA RACE
firstRowDatum.SetMysqlJSON(col.GetJSON(0).Copy())
lastRowDatum.SetMysqlJSON(col.GetJSON(numRows - 1).Copy())
if !firstRowIsNull {
// make a copy to avoid DATA RACE
firstRowDatum.SetMysqlJSON(col.GetJSON(0).Copy())
}
if !lastRowIsNull {
// make a copy to avoid DATA RACE
lastRowDatum.SetMysqlJSON(col.GetJSON(numRows - 1).Copy())
}
case types.ETString:
previousKey := codec.ConvertByCollationStr(col.GetString(0), tp)
for i := 1; i < numRows; i++ {
Expand All @@ -1225,9 +1270,14 @@ func (e *vecGroupChecker) evalGroupItemsAndResolveGroups(item expression.Express
previousKey = key
previousIsNull = isNull
}
// don't use col.GetString since it will cause DATA RACE
firstRowDatum.SetString(string(col.GetBytes(0)), tp.Collate)
lastRowDatum.SetString(string(col.GetBytes(numRows-1)), tp.Collate)
if !firstRowIsNull {
// don't use col.GetString since it will cause DATA RACE
firstRowDatum.SetString(string(col.GetBytes(0)), tp.Collate)
}
if !lastRowIsNull {
// don't use col.GetString since it will cause DATA RACE
lastRowDatum.SetString(string(col.GetBytes(numRows-1)), tp.Collate)
}
default:
err = errors.New(fmt.Sprintf("invalid eval type %v", eType))
}
Expand Down
51 changes: 51 additions & 0 deletions executor/aggregate_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -844,3 +844,54 @@ func (s *testSuiteAgg) TestPR15242ShallowCopy(c *C) {
tk.MustQuery(`select max(JSON_EXTRACT(a, '$.score')) as max_score,JSON_EXTRACT(a,'$.id') as id from t group by id order by id;`).Check(testkit.Rows("233 1", "233 2", "233 3"))

}

// TestIssue15690 runs `select /*+ stream_agg() */ distinct * from t` against
// single-column tables that contain NULL rows, once per column type, and
// checks both the returned rows and that the statement produced no warnings.
// The session's MaxChunkSize is set to 2 before any query is issued
// (presumably so that the rows are split across several chunks — confirm
// against the executor).
func (s *testSuiteAgg) TestIssue15690(c *C) {
	tk := testkit.NewTestKitWithInit(c, s.store)
	tk.Se.GetSessionVars().MaxChunkSize = 2

	scenarios := []struct {
		createSQL string   // DDL that creates table t with the column type under test
		inserts   []string // insert statements, executed in order
		expected  []string // rows the distinct query must return
	}{
		{
			// INT column
			createSQL: `create table t(a int);`,
			inserts: []string{
				`insert into t values(null),(null);`,
				`insert into t values(0),(2),(2),(4),(8);`,
			},
			expected: []string{"<nil>", "0", "2", "4", "8"},
		},
		{
			// FLOAT column
			createSQL: `create table t(a float);`,
			inserts: []string{
				`insert into t values(null),(null),(null),(null);`,
				`insert into t values(1.1),(1.1);`,
			},
			expected: []string{"<nil>", "1.1"},
		},
		{
			// DECIMAL column
			createSQL: `create table t(a decimal(5,1));`,
			inserts: []string{
				`insert into t values(null),(null),(null);`,
				`insert into t values(1.1),(2.2),(2.2);`,
			},
			expected: []string{"<nil>", "1.1", "2.2"},
		},
		{
			// DATETIME column
			createSQL: `create table t(a datetime);`,
			inserts: []string{
				`insert into t values(null);`,
				`insert into t values("2019-03-20 21:50:00"),("2019-03-20 21:50:01"), ("2019-03-20 21:50:00");`,
			},
			expected: []string{"<nil>", "2019-03-20 21:50:00", "2019-03-20 21:50:01"},
		},
		{
			// JSON column: only NULL rows are inserted
			createSQL: `create table t(a json);`,
			inserts: []string{
				`insert into t values(null),(null),(null),(null);`,
			},
			expected: []string{"<nil>"},
		},
		{
			// CHAR column
			createSQL: `create table t(a char);`,
			inserts: []string{
				`insert into t values(null),(null),(null),(null);`,
				`insert into t values('a'),('b');`,
			},
			expected: []string{"<nil>", "a", "b"},
		},
	}

	for _, sc := range scenarios {
		// Every scenario starts from a freshly created table t.
		tk.MustExec(`drop table if exists t;`)
		tk.MustExec(sc.createSQL)
		for _, stmt := range sc.inserts {
			tk.MustExec(stmt)
		}
		tk.MustQuery(`select /*+ stream_agg() */ distinct * from t;`).Check(testkit.Rows(sc.expected...))
		c.Assert(tk.Se.GetSessionVars().StmtCtx.WarningCount(), Equals, uint16(0))
	}
}
33 changes: 30 additions & 3 deletions executor/executor_required_rows_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -713,25 +713,40 @@ func (s *testExecSuite) TestMergeJoinRequiredRows(c *C) {

func genTestChunk4VecGroupChecker(chkRows []int, sameNum int) (expr []expression.Expression, inputs []*chunk.Chunk) {
chkNum := len(chkRows)
numRows := 0
inputs = make([]*chunk.Chunk, chkNum)
fts := make([]*types.FieldType, 1)
fts[0] = types.NewFieldType(mysql.TypeLonglong)
for i := 0; i < chkNum; i++ {
inputs[i] = chunk.New(fts, chkRows[i], chkRows[i])
numRows += chkRows[i]
}
var numGroups int
if numRows%sameNum == 0 {
numGroups = numRows / sameNum
} else {
numGroups = numRows/sameNum + 1
}

rand.Seed(time.Now().Unix())
nullPos := rand.Intn(numGroups)
cnt := 0
val := 0
val := rand.Int63()
for i := 0; i < chkNum; i++ {
col := inputs[i].Column(0)
col.ResizeInt64(chkRows[i], false)
i64s := col.Int64s()
for j := 0; j < chkRows[i]; j++ {
if cnt == sameNum {
val++
val = rand.Int63()
cnt = 0
nullPos--
}
if nullPos == 0 {
col.SetNull(j, true)
} else {
i64s[j] = val
}
i64s[j] = int64(val)
cnt++
}
}
Expand Down Expand Up @@ -775,6 +790,18 @@ func (s *testExecSuite) TestVecGroupChecker(c *C) {
expectedFlag: []bool{false, false},
sameNum: 1,
},
{
chunkRows: []int{2, 2},
expectedGroups: 2,
expectedFlag: []bool{false, false},
sameNum: 2,
},
{
chunkRows: []int{2, 2},
expectedGroups: 1,
expectedFlag: []bool{false, true},
sameNum: 4,
},
}

ctx := mock.NewContext()
Expand Down

0 comments on commit 0d10f91

Please sign in to comment.