diff --git a/executor/aggregate.go b/executor/aggregate.go index d46cb7da67ddd..35df9e0a52aa9 100644 --- a/executor/aggregate.go +++ b/executor/aggregate.go @@ -1111,8 +1111,15 @@ func (e *vecGroupChecker) evalGroupItemsAndResolveGroups(item expression.Express return err } - previousIsNull := col.IsNull(0) var firstRowDatum, lastRowDatum types.Datum + firstRowIsNull, lastRowIsNull := col.IsNull(0), col.IsNull(numRows-1) + if firstRowIsNull { + firstRowDatum.SetNull() + } + if lastRowIsNull { + lastRowDatum.SetNull() + } + previousIsNull := firstRowIsNull switch eType { case types.ETInt: vals := col.Int64s() @@ -1128,8 +1135,12 @@ func (e *vecGroupChecker) evalGroupItemsAndResolveGroups(item expression.Express } previousIsNull = isNull } - firstRowDatum.SetInt64(vals[0]) - lastRowDatum.SetInt64(vals[numRows-1]) + if !firstRowIsNull { + firstRowDatum.SetInt64(vals[0]) + } + if !lastRowIsNull { + lastRowDatum.SetInt64(vals[numRows-1]) + } case types.ETReal: vals := col.Float64s() for i := 1; i < numRows; i++ { @@ -1144,8 +1155,12 @@ func (e *vecGroupChecker) evalGroupItemsAndResolveGroups(item expression.Express } previousIsNull = isNull } - firstRowDatum.SetFloat64(vals[0]) - lastRowDatum.SetFloat64(vals[numRows-1]) + if !firstRowIsNull { + firstRowDatum.SetFloat64(vals[0]) + } + if !lastRowIsNull { + lastRowDatum.SetFloat64(vals[numRows-1]) + } case types.ETDecimal: vals := col.Decimals() for i := 1; i < numRows; i++ { @@ -1160,10 +1175,16 @@ func (e *vecGroupChecker) evalGroupItemsAndResolveGroups(item expression.Express } previousIsNull = isNull } - // make a copy to avoid DATA RACE - firstDatum, lastDatum := vals[0], vals[numRows-1] - firstRowDatum.SetMysqlDecimal(&firstDatum) - lastRowDatum.SetMysqlDecimal(&lastDatum) + if !firstRowIsNull { + // make a copy to avoid DATA RACE + firstDatum := vals[0] + firstRowDatum.SetMysqlDecimal(&firstDatum) + } + if !lastRowIsNull { + // make a copy to avoid DATA RACE + lastDatum := vals[numRows-1] + lastRowDatum.SetMysqlDecimal(&lastDatum) + } case types.ETDatetime, types.ETTimestamp: vals := col.Times() for i := 1; i < numRows; i++ { @@ -1178,8 +1199,12 @@ func (e *vecGroupChecker) evalGroupItemsAndResolveGroups(item expression.Express } previousIsNull = isNull } - firstRowDatum.SetMysqlTime(vals[0]) - lastRowDatum.SetMysqlTime(vals[numRows-1]) + if !firstRowIsNull { + firstRowDatum.SetMysqlTime(vals[0]) + } + if !lastRowIsNull { + lastRowDatum.SetMysqlTime(vals[numRows-1]) + } case types.ETDuration: vals := col.GoDurations() for i := 1; i < numRows; i++ { @@ -1194,24 +1219,44 @@ func (e *vecGroupChecker) evalGroupItemsAndResolveGroups(item expression.Express } previousIsNull = isNull } - firstRowDatum.SetMysqlDuration(types.Duration{Duration: vals[0], Fsp: int8(item.GetType().Decimal)}) - lastRowDatum.SetMysqlDuration(types.Duration{Duration: vals[numRows-1], Fsp: int8(item.GetType().Decimal)}) + if !firstRowIsNull { + firstRowDatum.SetMysqlDuration(types.Duration{Duration: vals[0], Fsp: int8(item.GetType().Decimal)}) + } + if !lastRowIsNull { + lastRowDatum.SetMysqlDuration(types.Duration{Duration: vals[numRows-1], Fsp: int8(item.GetType().Decimal)}) + } case types.ETJson: - previousKey := col.GetJSON(0) + var previousKey, key json.BinaryJSON + if !previousIsNull { + previousKey = col.GetJSON(0) + } for i := 1; i < numRows; i++ { - key := col.GetJSON(i) isNull := col.IsNull(i) + if !isNull { + key = col.GetJSON(i) + } if e.sameGroup[i] { - if isNull != previousIsNull || json.CompareBinary(previousKey, key) != 0 { + if isNull == previousIsNull { + if !isNull && json.CompareBinary(previousKey, key) != 0 { + e.sameGroup[i] = false + } + } else { e.sameGroup[i] = false } } - previousKey = key + if !isNull { + previousKey = key + } previousIsNull = isNull } - // make a copy to avoid DATA RACE - firstRowDatum.SetMysqlJSON(col.GetJSON(0).Copy()) - lastRowDatum.SetMysqlJSON(col.GetJSON(numRows - 1).Copy()) + if !firstRowIsNull { + // make a copy to avoid DATA RACE + firstRowDatum.SetMysqlJSON(col.GetJSON(0).Copy()) + } + if !lastRowIsNull { + // make a copy to avoid DATA RACE + lastRowDatum.SetMysqlJSON(col.GetJSON(numRows - 1).Copy()) + } case types.ETString: previousKey := codec.ConvertByCollationStr(col.GetString(0), tp) for i := 1; i < numRows; i++ { @@ -1225,9 +1270,14 @@ func (e *vecGroupChecker) evalGroupItemsAndResolveGroups(item expression.Express previousKey = key previousIsNull = isNull } - // don't use col.GetString since it will cause DATA RACE - firstRowDatum.SetString(string(col.GetBytes(0)), tp.Collate) - lastRowDatum.SetString(string(col.GetBytes(numRows-1)), tp.Collate) + if !firstRowIsNull { + // don't use col.GetString since it will cause DATA RACE + firstRowDatum.SetString(string(col.GetBytes(0)), tp.Collate) + } + if !lastRowIsNull { + // don't use col.GetString since it will cause DATA RACE + lastRowDatum.SetString(string(col.GetBytes(numRows-1)), tp.Collate) + } default: err = errors.New(fmt.Sprintf("invalid eval type %v", eType)) } diff --git a/executor/aggregate_test.go b/executor/aggregate_test.go index 12fbc9700cc3a..0219cdd0f5182 100644 --- a/executor/aggregate_test.go +++ b/executor/aggregate_test.go @@ -844,3 +844,54 @@ func (s *testSuiteAgg) TestPR15242ShallowCopy(c *C) { tk.MustQuery(`select max(JSON_EXTRACT(a, '$.score')) as max_score,JSON_EXTRACT(a,'$.id') as id from t group by id order by id;`).Check(testkit.Rows("233 1", "233 2", "233 3")) } + +func (s *testSuiteAgg) TestIssue15690(c *C) { + tk := testkit.NewTestKitWithInit(c, s.store) + tk.Se.GetSessionVars().MaxChunkSize = 2 + // check for INT type + tk.MustExec(`drop table if exists t;`) + tk.MustExec(`create table t(a int);`) + tk.MustExec(`insert into t values(null),(null);`) + tk.MustExec(`insert into t values(0),(2),(2),(4),(8);`) + tk.MustQuery(`select /*+ stream_agg() */ distinct * from t;`).Check(testkit.Rows("", "0", "2", "4", "8")) + c.Assert(tk.Se.GetSessionVars().StmtCtx.WarningCount(), Equals, uint16(0)) + + // check for FLOAT type + tk.MustExec(`drop table if exists t;`) + tk.MustExec(`create table t(a float);`) + tk.MustExec(`insert into t values(null),(null),(null),(null);`) + tk.MustExec(`insert into t values(1.1),(1.1);`) + tk.MustQuery(`select /*+ stream_agg() */ distinct * from t;`).Check(testkit.Rows("", "1.1")) + c.Assert(tk.Se.GetSessionVars().StmtCtx.WarningCount(), Equals, uint16(0)) + + // check for DECIMAL type + tk.MustExec(`drop table if exists t;`) + tk.MustExec(`create table t(a decimal(5,1));`) + tk.MustExec(`insert into t values(null),(null),(null);`) + tk.MustExec(`insert into t values(1.1),(2.2),(2.2);`) + tk.MustQuery(`select /*+ stream_agg() */ distinct * from t;`).Check(testkit.Rows("", "1.1", "2.2")) + c.Assert(tk.Se.GetSessionVars().StmtCtx.WarningCount(), Equals, uint16(0)) + + // check for DATETIME type + tk.MustExec(`drop table if exists t;`) + tk.MustExec(`create table t(a datetime);`) + tk.MustExec(`insert into t values(null);`) + tk.MustExec(`insert into t values("2019-03-20 21:50:00"),("2019-03-20 21:50:01"), ("2019-03-20 21:50:00");`) + tk.MustQuery(`select /*+ stream_agg() */ distinct * from t;`).Check(testkit.Rows("", "2019-03-20 21:50:00", "2019-03-20 21:50:01")) + c.Assert(tk.Se.GetSessionVars().StmtCtx.WarningCount(), Equals, uint16(0)) + + // check for JSON type + tk.MustExec(`drop table if exists t;`) + tk.MustExec(`create table t(a json);`) + tk.MustExec(`insert into t values(null),(null),(null),(null);`) + tk.MustQuery(`select /*+ stream_agg() */ distinct * from t;`).Check(testkit.Rows("")) + c.Assert(tk.Se.GetSessionVars().StmtCtx.WarningCount(), Equals, uint16(0)) + + // check for char type + tk.MustExec(`drop table if exists t;`) + tk.MustExec(`create table t(a char);`) + tk.MustExec(`insert into t values(null),(null),(null),(null);`) + tk.MustExec(`insert into t values('a'),('b');`) + tk.MustQuery(`select /*+ stream_agg() */ distinct * from t;`).Check(testkit.Rows("", "a", "b")) + c.Assert(tk.Se.GetSessionVars().StmtCtx.WarningCount(), Equals, uint16(0)) +} diff --git a/executor/executor_required_rows_test.go b/executor/executor_required_rows_test.go index 367f4d9426a47..56527396346a9 100644 --- a/executor/executor_required_rows_test.go +++ b/executor/executor_required_rows_test.go @@ -713,25 +713,40 @@ func (s *testExecSuite) TestMergeJoinRequiredRows(c *C) { func genTestChunk4VecGroupChecker(chkRows []int, sameNum int) (expr []expression.Expression, inputs []*chunk.Chunk) { chkNum := len(chkRows) + numRows := 0 inputs = make([]*chunk.Chunk, chkNum) fts := make([]*types.FieldType, 1) fts[0] = types.NewFieldType(mysql.TypeLonglong) for i := 0; i < chkNum; i++ { inputs[i] = chunk.New(fts, chkRows[i], chkRows[i]) + numRows += chkRows[i] + } + var numGroups int + if numRows%sameNum == 0 { + numGroups = numRows / sameNum + } else { + numGroups = numRows/sameNum + 1 } + rand.Seed(time.Now().Unix()) + nullPos := rand.Intn(numGroups) cnt := 0 - val := 0 + val := rand.Int63() for i := 0; i < chkNum; i++ { col := inputs[i].Column(0) col.ResizeInt64(chkRows[i], false) i64s := col.Int64s() for j := 0; j < chkRows[i]; j++ { if cnt == sameNum { - val++ + val = rand.Int63() cnt = 0 + nullPos-- + } + if nullPos == 0 { + col.SetNull(j, true) + } else { + i64s[j] = val } - i64s[j] = int64(val) cnt++ } } @@ -775,6 +790,18 @@ func (s *testExecSuite) TestVecGroupChecker(c *C) { expectedFlag: []bool{false, false}, sameNum: 1, }, + { + chunkRows: []int{2, 2}, + expectedGroups: 2, + expectedFlag: []bool{false, false}, + sameNum: 2, + }, + { + chunkRows: []int{2, 2}, + expectedGroups: 1, + expectedFlag: []bool{false, true}, + sameNum: 4, + }, } ctx := mock.NewContext()