Skip to content

Commit

Permalink
executor: add builtin aggregate function json_arrayagg (#19957)
Browse files Browse the repository at this point in the history
  • Loading branch information
arthuryangcs authored Aug 4, 2021
1 parent 15ca386 commit 853c41e
Show file tree
Hide file tree
Showing 11 changed files with 371 additions and 5 deletions.
60 changes: 60 additions & 0 deletions executor/aggfuncs/aggfunc_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -389,6 +389,9 @@ func (s *testSuite) testMergePartialResult(c *C, p aggTest) {
if p.funcName == ast.AggFuncApproxCountDistinct {
resultChk = chunk.NewChunkWithCapacity([]*types.FieldType{types.NewFieldType(mysql.TypeString)}, 1)
}
if p.funcName == ast.AggFuncJsonArrayagg {
resultChk = chunk.NewChunkWithCapacity([]*types.FieldType{types.NewFieldType(mysql.TypeJSON)}, 1)
}

// update partial result.
for row := iter.Begin(); row != iter.End(); row = iter.Next() {
Expand All @@ -402,6 +405,9 @@ func (s *testSuite) testMergePartialResult(c *C, p aggTest) {
if p.funcName == ast.AggFuncApproxCountDistinct {
dt = resultChk.GetRow(0).GetDatum(0, types.NewFieldType(mysql.TypeString))
}
if p.funcName == ast.AggFuncJsonArrayagg {
dt = resultChk.GetRow(0).GetDatum(0, types.NewFieldType(mysql.TypeJSON))
}
result, err := dt.CompareDatum(s.ctx.GetSessionVars().StmtCtx, &p.results[0])
c.Assert(err, IsNil)
c.Assert(result, Equals, 0, Commentf("%v != %v", dt.String(), p.results[0]))
Expand All @@ -426,6 +432,9 @@ func (s *testSuite) testMergePartialResult(c *C, p aggTest) {
if p.funcName == ast.AggFuncApproxCountDistinct {
dt = resultChk.GetRow(0).GetDatum(0, types.NewFieldType(mysql.TypeString))
}
if p.funcName == ast.AggFuncJsonArrayagg {
dt = resultChk.GetRow(0).GetDatum(0, types.NewFieldType(mysql.TypeJSON))
}
result, err = dt.CompareDatum(s.ctx.GetSessionVars().StmtCtx, &p.results[1])
c.Assert(err, IsNil)
c.Assert(result, Equals, 0, Commentf("%v != %v", dt.String(), p.results[1]))
Expand All @@ -435,6 +444,9 @@ func (s *testSuite) testMergePartialResult(c *C, p aggTest) {
if p.funcName == ast.AggFuncApproxCountDistinct {
resultChk = chunk.NewChunkWithCapacity([]*types.FieldType{types.NewFieldType(mysql.TypeLonglong)}, 1)
}
if p.funcName == ast.AggFuncJsonArrayagg {
resultChk = chunk.NewChunkWithCapacity([]*types.FieldType{types.NewFieldType(mysql.TypeJSON)}, 1)
}
resultChk.Reset()
err = finalFunc.AppendFinalResult2Chunk(s.ctx, finalPr, resultChk)
c.Assert(err, IsNil)
Expand All @@ -443,6 +455,9 @@ func (s *testSuite) testMergePartialResult(c *C, p aggTest) {
if p.funcName == ast.AggFuncApproxCountDistinct {
dt = resultChk.GetRow(0).GetDatum(0, types.NewFieldType(mysql.TypeLonglong))
}
if p.funcName == ast.AggFuncJsonArrayagg {
dt = resultChk.GetRow(0).GetDatum(0, types.NewFieldType(mysql.TypeJSON))
}
result, err = dt.CompareDatum(s.ctx.GetSessionVars().StmtCtx, &p.results[2])
c.Assert(err, IsNil)
c.Assert(result, Equals, 0, Commentf("%v != %v", dt.String(), p.results[2]))
Expand Down Expand Up @@ -687,6 +702,51 @@ func (s *testSuite) testAggFunc(c *C, p aggTest) {
c.Assert(result, Equals, 0, Commentf("%v != %v", dt.String(), p.results[0]))
}

func (s *testSuite) testAggFuncWithoutDistinct(c *C, p aggTest) {
srcChk := p.genSrcChk()

args := []expression.Expression{&expression.Column{RetType: p.dataType, Index: 0}}
if p.funcName == ast.AggFuncGroupConcat {
args = append(args, &expression.Constant{Value: types.NewStringDatum(separator), RetType: types.NewFieldType(mysql.TypeString)})
}
if p.funcName == ast.AggFuncApproxPercentile {
args = append(args, &expression.Constant{Value: types.NewIntDatum(50), RetType: types.NewFieldType(mysql.TypeLong)})
}
desc, err := aggregation.NewAggFuncDesc(s.ctx, p.funcName, args, false)
c.Assert(err, IsNil)
if p.orderBy {
desc.OrderByItems = []*util.ByItems{
{Expr: args[0], Desc: true},
}
}
finalFunc := aggfuncs.Build(s.ctx, desc, 0)
finalPr, _ := finalFunc.AllocPartialResult()
resultChk := chunk.NewChunkWithCapacity([]*types.FieldType{desc.RetTp}, 1)

iter := chunk.NewIterator4Chunk(srcChk)
for row := iter.Begin(); row != iter.End(); row = iter.Next() {
_, err = finalFunc.UpdatePartialResult(s.ctx, []chunk.Row{row}, finalPr)
c.Assert(err, IsNil)
}
p.messUpChunk(srcChk)
err = finalFunc.AppendFinalResult2Chunk(s.ctx, finalPr, resultChk)
c.Assert(err, IsNil)
dt := resultChk.GetRow(0).GetDatum(0, desc.RetTp)
result, err := dt.CompareDatum(s.ctx.GetSessionVars().StmtCtx, &p.results[1])
c.Assert(err, IsNil)
c.Assert(result, Equals, 0, Commentf("%v != %v", dt.String(), p.results[1]))

// test the empty input
resultChk.Reset()
finalFunc.ResetPartialResult(finalPr)
err = finalFunc.AppendFinalResult2Chunk(s.ctx, finalPr, resultChk)
c.Assert(err, IsNil)
dt = resultChk.GetRow(0).GetDatum(0, desc.RetTp)
result, err = dt.CompareDatum(s.ctx.GetSessionVars().StmtCtx, &p.results[0])
c.Assert(err, IsNil)
c.Assert(result, Equals, 0, Commentf("%v != %v", dt.String(), p.results[0]))
}

func (s *testSuite) testAggMemFunc(c *C, p aggMemTest) {
srcChk := p.aggTest.genSrcChk()

Expand Down
3 changes: 3 additions & 0 deletions executor/aggfuncs/aggfuncs.go
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,9 @@ var (
// All the AggFunc implementations for "BIT_AND" are listed here.
_ AggFunc = (*bitAndUint64)(nil)

// All the AggFunc implementations for "JSON_ARRAYAGG" are listed here
_ AggFunc = (*jsonArrayagg)(nil)

// All the AggFunc implementations for "JSON_OBJECTAGG" are listed here
_ AggFunc = (*jsonObjectAgg)(nil)
)
Expand Down
16 changes: 16 additions & 0 deletions executor/aggfuncs/builder.go
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,8 @@ func Build(ctx sessionctx.Context, aggFuncDesc *aggregation.AggFuncDesc, ordinal
return buildVarPop(aggFuncDesc, ordinal)
case ast.AggFuncStddevPop:
return buildStdDevPop(aggFuncDesc, ordinal)
case ast.AggFuncJsonArrayagg:
return buildJSONArrayagg(aggFuncDesc, ordinal)
case ast.AggFuncJsonObjectAgg:
return buildJSONObjectAgg(aggFuncDesc, ordinal)
case ast.AggFuncApproxCountDistinct:
Expand Down Expand Up @@ -615,6 +617,20 @@ func buildStddevSamp(aggFuncDesc *aggregation.AggFuncDesc, ordinal int) AggFunc
}
}

// buildJSONArrayagg builds the AggFunc implementation for function "json_arrayagg".
func buildJSONArrayagg(aggFuncDesc *aggregation.AggFuncDesc, ordinal int) AggFunc {
base := baseAggFunc{
args: aggFuncDesc.Args,
ordinal: ordinal,
}
switch aggFuncDesc.Mode {
case aggregation.DedupMode:
return nil
default:
return &jsonArrayagg{base}
}
}

// buildJSONObjectAgg builds the AggFunc implementation for function "json_objectagg".
func buildJSONObjectAgg(aggFuncDesc *aggregation.AggFuncDesc, ordinal int) AggFunc {
base := baseAggFunc{
Expand Down
103 changes: 103 additions & 0 deletions executor/aggfuncs/func_json_arrayagg.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,103 @@
// Copyright 2020 PingCAP, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// See the License for the specific language governing permissions and
// limitations under the License.

package aggfuncs

import (
"unsafe"

"github.com/pingcap/errors"
"github.com/pingcap/tidb/sessionctx"
"github.com/pingcap/tidb/types"
"github.com/pingcap/tidb/types/json"
"github.com/pingcap/tidb/util/chunk"
)

const (
// DefPartialResult4JsonArrayagg is the size of partialResult4JsonArrayagg
DefPartialResult4JsonArrayagg = int64(unsafe.Sizeof(partialResult4JsonArrayagg{}))
)

type jsonArrayagg struct {
baseAggFunc
}

type partialResult4JsonArrayagg struct {
entries []interface{}
}

func (e *jsonArrayagg) AllocPartialResult() (pr PartialResult, memDelta int64) {
p := partialResult4JsonArrayagg{}
p.entries = make([]interface{}, 0)
return PartialResult(&p), DefPartialResult4JsonArrayagg + DefSliceSize
}

func (e *jsonArrayagg) ResetPartialResult(pr PartialResult) {
p := (*partialResult4JsonArrayagg)(pr)
p.entries = p.entries[:0]
}

func (e *jsonArrayagg) AppendFinalResult2Chunk(sctx sessionctx.Context, pr PartialResult, chk *chunk.Chunk) error {
p := (*partialResult4JsonArrayagg)(pr)
if len(p.entries) == 0 {
chk.AppendNull(e.ordinal)
return nil
}

// appendBinary does not support some type such as uint8、types.time,so convert is needed here
for idx, val := range p.entries {
switch x := val.(type) {
case *types.MyDecimal:
float64Val, err := x.ToFloat64()
if err != nil {
return errors.Trace(err)
}
p.entries[idx] = float64Val
case []uint8, types.Time, types.Duration:
strVal, err := types.ToString(x)
if err != nil {
return errors.Trace(err)
}
p.entries[idx] = strVal
}
}

chk.AppendJSON(e.ordinal, json.CreateBinary(p.entries))
return nil
}

func (e *jsonArrayagg) UpdatePartialResult(sctx sessionctx.Context, rowsInGroup []chunk.Row, pr PartialResult) (memDelta int64, err error) {
p := (*partialResult4JsonArrayagg)(pr)
for _, row := range rowsInGroup {
item, err := e.args[0].Eval(row)
if err != nil {
return 0, errors.Trace(err)
}

realItem := item.Clone().GetValue()
switch x := realItem.(type) {
case nil, bool, int64, uint64, float64, string, json.BinaryJSON, *types.MyDecimal, []uint8, types.Time, types.Duration:
p.entries = append(p.entries, realItem)
memDelta += getValMemDelta(realItem)
default:
return 0, json.ErrUnsupportedSecondArgumentType.GenWithStackByArgs(x)
}
}
return memDelta, nil
}

func (e *jsonArrayagg) MergePartialResult(sctx sessionctx.Context, src, dst PartialResult) (memDelta int64, err error) {
p1, p2 := (*partialResult4JsonArrayagg)(src), (*partialResult4JsonArrayagg)(dst)
p2.entries = append(p2.entries, p1.entries...)
return 0, nil
}
139 changes: 139 additions & 0 deletions executor/aggfuncs/func_json_arrayagg_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,139 @@
// Copyright 2020 PingCAP, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// See the License for the specific language governing permissions and
// limitations under the License.

package aggfuncs_test

import (
. "github.com/pingcap/check"
"github.com/pingcap/errors"
"github.com/pingcap/parser/ast"
"github.com/pingcap/parser/mysql"
"github.com/pingcap/tidb/executor/aggfuncs"
"github.com/pingcap/tidb/types"
"github.com/pingcap/tidb/types/json"
"github.com/pingcap/tidb/util/chunk"
)

func (s *testSuite) TestMergePartialResult4JsonArrayagg(c *C) {
typeList := []byte{mysql.TypeLonglong, mysql.TypeDouble, mysql.TypeString, mysql.TypeJSON}

var tests []aggTest
numRows := 5
for _, argType := range typeList {
entries1 := make([]interface{}, 0)
entries2 := make([]interface{}, 0)
entries3 := make([]interface{}, 0)

genFunc := getDataGenFunc(types.NewFieldType(argType))

for m := 0; m < numRows; m++ {
arg := genFunc(m)
entries1 = append(entries1, arg.GetValue())
}
// to adapt the `genSrcChk` Chunk format
entries1 = append(entries1, nil)

for m := 2; m < numRows; m++ {
arg := genFunc(m)
entries2 = append(entries2, arg.GetValue())
}
// to adapt the `genSrcChk` Chunk format
entries2 = append(entries2, nil)

entries3 = append(entries3, entries1...)
entries3 = append(entries3, entries2...)

tests = append(tests, buildAggTester(ast.AggFuncJsonArrayagg, argType, numRows, json.CreateBinary(entries1), json.CreateBinary(entries2), json.CreateBinary(entries3)))
}

for _, test := range tests {
s.testMergePartialResult(c, test)
}
}

func (s *testSuite) TestJsonArrayagg(c *C) {
typeList := []byte{mysql.TypeLonglong, mysql.TypeDouble, mysql.TypeString, mysql.TypeJSON}

var tests []aggTest
numRows := 5

for _, argType := range typeList {
entries := make([]interface{}, 0)

genFunc := getDataGenFunc(types.NewFieldType(argType))

for m := 0; m < numRows; m++ {
arg := genFunc(m)
entries = append(entries, arg.GetValue())
}
// to adapt the `genSrcChk` Chunk format
entries = append(entries, nil)

tests = append(tests, buildAggTester(ast.AggFuncJsonArrayagg, argType, numRows, nil, json.CreateBinary(entries)))
}

for _, test := range tests {
s.testAggFuncWithoutDistinct(c, test)
}
}

func jsonArrayaggMemDeltaGens(srcChk *chunk.Chunk, dataType *types.FieldType) (memDeltas []int64, err error) {
memDeltas = make([]int64, 0)
for i := 0; i < srcChk.NumRows(); i++ {
row := srcChk.GetRow(i)
if row.IsNull(0) {
memDeltas = append(memDeltas, aggfuncs.DefInterfaceSize)
continue
}

memDelta := int64(0)
memDelta += aggfuncs.DefInterfaceSize
switch dataType.Tp {
case mysql.TypeLonglong:
memDelta += aggfuncs.DefUint64Size
case mysql.TypeDouble:
memDelta += aggfuncs.DefFloat64Size
case mysql.TypeString:
val := row.GetString(0)
memDelta += int64(len(val))
case mysql.TypeJSON:
val := row.GetJSON(0)
// +1 for the memory usage of the TypeCode of json
memDelta += int64(len(val.Value) + 1)
case mysql.TypeDuration:
memDelta += aggfuncs.DefDurationSize
case mysql.TypeDate:
memDelta += aggfuncs.DefTimeSize
case mysql.TypeNewDecimal:
memDelta += aggfuncs.DefMyDecimalSize
default:
return memDeltas, errors.Errorf("unsupported type - %v", dataType.Tp)
}
memDeltas = append(memDeltas, memDelta)
}
return memDeltas, nil
}

func (s *testSuite) TestMemJsonArrayagg(c *C) {
typeList := []byte{mysql.TypeLonglong, mysql.TypeDouble, mysql.TypeString, mysql.TypeJSON}

var tests []aggMemTest
numRows := 5
for _, argType := range typeList {
tests = append(tests, buildAggMemTester(ast.AggFuncJsonArrayagg, argType, numRows, aggfuncs.DefPartialResult4JsonArrayagg+aggfuncs.DefSliceSize, jsonArrayaggMemDeltaGens, false))
}

for _, test := range tests {
s.testAggMemFunc(c, test)
}
}
2 changes: 2 additions & 0 deletions expression/aggregation/agg_to_pb.go
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,8 @@ func AggFuncToPBExpr(sc *stmtctx.StatementContext, client kv.Client, aggFunc *Ag
tp = tipb.ExprType_Agg_BitAnd
case ast.AggFuncVarPop:
tp = tipb.ExprType_VarPop
case ast.AggFuncJsonArrayagg:
tp = tipb.ExprType_JsonArrayAgg
case ast.AggFuncJsonObjectAgg:
tp = tipb.ExprType_JsonObjectAgg
case ast.AggFuncStddevPop:
Expand Down
Loading

0 comments on commit 853c41e

Please sign in to comment.