Skip to content

Commit

Permalink
expression: add agg function var_pop()
Browse files Browse the repository at this point in the history
  • Loading branch information
mccxj committed Sep 17, 2018
1 parent b30dbd0 commit 99c7f3a
Show file tree
Hide file tree
Showing 15 changed files with 413 additions and 24 deletions.
2 changes: 2 additions & 0 deletions ast/functions.go
Original file line number Diff line number Diff line change
Expand Up @@ -482,6 +482,8 @@ const (
AggFuncBitXor = "bit_xor"
// AggFuncBitAnd is the name of bit_and function.
AggFuncBitAnd = "bit_and"
// AggFuncVarPop is the name of var_pop function.
AggFuncVarPop = "var_pop"
)

// AggregateFuncExpr represents aggregate function expression.
Expand Down
4 changes: 4 additions & 0 deletions executor/aggfuncs/aggfuncs.go
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,10 @@ var (

// All the AggFunc implementations for "BIT_AND" are listed here.
_ AggFunc = (*bitAndUint64)(nil)

// All the AggFunc implementations for "VAR_POP" are listed here.
_ AggFunc = (*varPopOriginal4Float64)(nil)
_ AggFunc = (*varPopPartial4Float64)(nil)
)

// PartialResult represents data structure to store the partial result for the
Expand Down
25 changes: 25 additions & 0 deletions executor/aggfuncs/builder.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,8 @@ func Build(ctx sessionctx.Context, aggFuncDesc *aggregation.AggFuncDesc, ordinal
return buildBitXor(aggFuncDesc, ordinal)
case ast.AggFuncBitAnd:
return buildBitAnd(aggFuncDesc, ordinal)
case ast.AggFuncVarPop:
return buildVarPop(aggFuncDesc, ordinal)
}
return nil
}
Expand Down Expand Up @@ -312,3 +314,26 @@ func buildBitAnd(aggFuncDesc *aggregation.AggFuncDesc, ordinal int) AggFunc {
}
return &bitAndUint64{baseBitAggFunc{base}}
}

// buildVarPop builds the AggFunc implementation for function "VAR_POP".
func buildVarPop(aggFuncDesc *aggregation.AggFuncDesc, ordinal int) AggFunc {
base := baseAggFunc{
args: aggFuncDesc.Args,
ordinal: ordinal,
}
switch aggFuncDesc.Mode {
// Build var_pop functions which consume the original data and remove the
// duplicated input of the same group.
case aggregation.DedupMode:
return nil // not implemented yet.
// Build var_pop functions which consume the original data and update their
// partial results.
case aggregation.CompleteMode, aggregation.Partial1Mode:
return &varPopOriginal4Float64{baseVarPopFloat64{base}}
// Build var_pop functions which consume the partial result of other avg
// functions and update their partial results.
case aggregation.Partial2Mode, aggregation.FinalMode:
return &varPopPartial4Float64{baseVarPopFloat64{base}}
}
return nil
}
122 changes: 122 additions & 0 deletions executor/aggfuncs/func_var_pop.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,122 @@
// Copyright 2018 PingCAP, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// See the License for the specific language governing permissions and
// limitations under the License.

package aggfuncs

import (
"github.com/pingcap/tidb/sessionctx"
"github.com/pingcap/tidb/util/chunk"
"github.com/pkg/errors"
)

// All the following avg function implementations return the float64 result,
// which store the partial results in "partialResult4VarPopFloat64".
//
// "baseVarPopFloat64" is wrapped by:
// - "varPopOriginal4Float64"
// - "varPopPartial4Float64"
type baseVarPopFloat64 struct {
baseAggFunc
}

type partialResult4VarPopFloat64 struct {
sum float64
squareSum float64
count int64
}

func (e *baseVarPopFloat64) AllocPartialResult() PartialResult {
return (PartialResult)(&partialResult4VarPopFloat64{})
}

func (e *baseVarPopFloat64) ResetPartialResult(pr PartialResult) {
p := (*partialResult4VarPopFloat64)(pr)
p.sum, p.squareSum, p.count = 0, 0, 0
}

func (e *baseVarPopFloat64) AppendFinalResult2Chunk(sctx sessionctx.Context, pr PartialResult, chk *chunk.Chunk) error {
p := (*partialResult4VarPopFloat64)(pr)
if p.count == 0 {
chk.AppendNull(e.ordinal)
} else {
chk.AppendFloat64(e.ordinal, (p.squareSum-p.sum*p.sum/float64(p.count))/float64(p.count))
}
return nil
}

type varPopOriginal4Float64 struct {
baseVarPopFloat64
}

func (e *varPopOriginal4Float64) UpdatePartialResult(sctx sessionctx.Context, rowsInGroup []chunk.Row, pr PartialResult) error {
p := (*partialResult4VarPopFloat64)(pr)
for _, row := range rowsInGroup {
input, isNull, err := e.args[0].EvalReal(sctx, row)
if err != nil {
return errors.Trace(err)
}
if isNull {
continue
}

p.sum += input
p.squareSum += input * input
p.count++
}
return nil
}

type varPopPartial4Float64 struct {
baseVarPopFloat64
}

func (e *varPopPartial4Float64) UpdatePartialResult(sctx sessionctx.Context, rowsInGroup []chunk.Row, pr PartialResult) error {
p := (*partialResult4VarPopFloat64)(pr)
for _, row := range rowsInGroup {
inputSquareSum, isNull, err := e.args[2].EvalReal(sctx, row)
if err != nil {
return errors.Trace(err)
}
if isNull {
continue
}

inputSum, isNull, err := e.args[1].EvalReal(sctx, row)
if err != nil {
return errors.Trace(err)
}
if isNull {
continue
}

inputCount, isNull, err := e.args[0].EvalInt(sctx, row)
if err != nil {
return errors.Trace(err)
}
if isNull {
continue
}
p.sum += inputSum
p.squareSum += inputSquareSum
p.count += inputCount
}
return nil
}

func (e *varPopPartial4Float64) MergePartialResult(sctx sessionctx.Context, src PartialResult, dst PartialResult) error {
p1, p2 := (*partialResult4VarPopFloat64)(src), (*partialResult4VarPopFloat64)(dst)
p2.sum += p1.sum
p2.squareSum += p1.squareSum
p2.count += p1.count
return nil
}
16 changes: 10 additions & 6 deletions executor/aggregate_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,8 @@ func (s *testSuite) TestAggregation(c *C) {
result.Check(testkit.Rows("1", "3", "4"))
result = tk.MustQuery("select avg(c) as a from t group by d order by a")
result.Check(testkit.Rows("1.0000", "2.0000", "2.5000"))
result = tk.MustQuery("select var_pop(c) as a from t group by d order by a")
result.Check(testkit.Rows("0", "1", "2.25"))
result = tk.MustQuery("select d, d + 1 from t group by d order by d")
result.Check(testkit.Rows("1 2", "2 3", "3 4"))
result = tk.MustQuery("select count(*) from t")
Expand Down Expand Up @@ -189,6 +191,8 @@ func (s *testSuite) TestAggregation(c *C) {
tk.MustExec("insert into t1 (a, b) values (1, 1),(2, 2),(3, 3),(1, 4),(3, 5)")
result = tk.MustQuery("select avg(b) from (select * from t1) t group by a order by a")
result.Check(testkit.Rows("2.5000", "2.0000", "4.0000"))
result = tk.MustQuery("select var_pop(b) from (select * from t1) t group by a order by a")
result.Check(testkit.Rows("2.25", "0", "1"))
result = tk.MustQuery("select sum(b) from (select * from t1) t group by a order by a")
result.Check(testkit.Rows("5", "2", "8"))
result = tk.MustQuery("select count(b) from (select * from t1) t group by a order by a")
Expand Down Expand Up @@ -303,15 +307,15 @@ func (s *testSuite) TestAggregation(c *C) {
tk.MustQuery("select sum(tags->'$.i') from t").Check(testkit.Rows("14"))

// test agg with empty input
result = tk.MustQuery("select id, count(95), sum(95), avg(95), bit_or(95), bit_and(95), bit_or(95), max(95), min(95), group_concat(95) from t where null")
result.Check(testkit.Rows("<nil> 0 <nil> <nil> 0 18446744073709551615 0 <nil> <nil> <nil>"))
result = tk.MustQuery("select id, count(95), sum(95), avg(95), bit_or(95), bit_and(95), bit_or(95), max(95), min(95), group_concat(95), var_pop(95) from t where null")
result.Check(testkit.Rows("<nil> 0 <nil> <nil> 0 18446744073709551615 0 <nil> <nil> <nil> <nil>"))
tk.MustExec("truncate table t")
tk.MustExec("create table s(id int)")
result = tk.MustQuery("select t.id, count(95), sum(95), avg(95), bit_or(95), bit_and(95), bit_or(95), max(95), min(95), group_concat(95) from t left join s on t.id = s.id")
result.Check(testkit.Rows("<nil> 0 <nil> <nil> 0 18446744073709551615 0 <nil> <nil> <nil>"))
result = tk.MustQuery("select t.id, count(95), sum(95), avg(95), bit_or(95), bit_and(95), bit_or(95), max(95), min(95), group_concat(95), var_pop(95) from t left join s on t.id = s.id")
result.Check(testkit.Rows("<nil> 0 <nil> <nil> 0 18446744073709551615 0 <nil> <nil> <nil> <nil>"))
tk.MustExec(`insert into t values (1, '{"i": 1, "n": "n1"}')`)
result = tk.MustQuery("select t.id, count(95), sum(95), avg(95), bit_or(95), bit_and(95), bit_or(95), max(95), min(95), group_concat(95) from t left join s on t.id = s.id")
result.Check(testkit.Rows("1 1 95 95.0000 95 95 95 95 95 95"))
result = tk.MustQuery("select t.id, count(95), sum(95), avg(95), bit_or(95), bit_and(95), bit_or(95), max(95), min(95), group_concat(95), var_pop(95) from t left join s on t.id = s.id")
result.Check(testkit.Rows("1 1 95 95.0000 95 95 95 95 95 95 0"))
tk.MustExec("set @@tidb_hash_join_concurrency=5")

// test agg bit col
Expand Down
6 changes: 6 additions & 0 deletions executor/builder.go
Original file line number Diff line number Diff line change
Expand Up @@ -996,6 +996,12 @@ func (b *executorBuilder) buildHashAgg(v *plan.PhysicalHashAgg) Executor {
ordinal = append(ordinal, partialOrdinal+1)
partialOrdinal++
}
if aggDesc.Name == ast.AggFuncVarPop {
ordinal = append(ordinal, partialOrdinal+1)
partialOrdinal++
ordinal = append(ordinal, partialOrdinal+2)
partialOrdinal++
}
finalDesc := aggDesc.Split(ordinal)
e.PartialAggFuncs = append(e.PartialAggFuncs, aggfuncs.Build(b.ctx, aggDesc, i))
e.FinalAggFuncs = append(e.FinalAggFuncs, aggfuncs.Build(b.ctx, finalDesc, i))
Expand Down
2 changes: 2 additions & 0 deletions expression/aggregation/agg_to_pb.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,8 @@ func AggFuncToPBExpr(sc *stmtctx.StatementContext, client kv.Client, aggFunc *Ag
tp = tipb.ExprType_Agg_BitXor
case ast.AggFuncBitAnd:
tp = tipb.ExprType_Agg_BitAnd
case ast.AggFuncVarPop:
tp = tipb.ExprType_VarPop
}
if !client.IsRequestTypeSupported(kv.ReqTypeSelect, int64(tp)) {
return nil
Expand Down
8 changes: 6 additions & 2 deletions expression/aggregation/aggregation.go
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,8 @@ func NewDistAggFunc(expr *tipb.Expr, fieldTps []*types.FieldType, sc *stmtctx.St
return &bitXorFunction{aggFunction: newAggFunc(ast.AggFuncBitXor, args, false)}, nil
case tipb.ExprType_Agg_BitAnd:
return &bitAndFunction{aggFunction: newAggFunc(ast.AggFuncBitAnd, args, false)}, nil
case tipb.ExprType_VarPop:
return &varPopFunction{aggFunction: newAggFunc(ast.AggFuncVarPop, args, false)}, nil
}
return nil, errors.Errorf("Unknown aggregate function type %v", expr.Tp)
}
Expand All @@ -97,6 +99,7 @@ type AggEvaluateContext struct {
DistinctChecker *distinctChecker
Count int64
Value types.Datum
Extra types.Datum
Buffer *bytes.Buffer // Buffer is used for group_concat.
GotFirstRow bool // It will check if the agg has met the first row key.
}
Expand Down Expand Up @@ -227,14 +230,15 @@ func (af *aggFunction) Clone(ctx sessionctx.Context) Aggregation {

// NeedCount indicates whether the aggregate function should record count.
func NeedCount(name string) bool {
return name == ast.AggFuncCount || name == ast.AggFuncAvg
return name == ast.AggFuncCount || name == ast.AggFuncAvg || name == ast.AggFuncVarPop
}

// NeedValue indicates whether the aggregate function should record value.
func NeedValue(name string) bool {
switch name {
case ast.AggFuncSum, ast.AggFuncAvg, ast.AggFuncFirstRow, ast.AggFuncMax, ast.AggFuncMin,
ast.AggFuncGroupConcat, ast.AggFuncBitOr, ast.AggFuncBitAnd, ast.AggFuncBitXor:
ast.AggFuncGroupConcat, ast.AggFuncBitOr, ast.AggFuncBitAnd, ast.AggFuncBitXor,
ast.AggFuncVarPop:
return true
default:
return false
Expand Down
55 changes: 55 additions & 0 deletions expression/aggregation/aggregation_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -564,3 +564,58 @@ func (s *testAggFuncSuit) TestMaxMin(c *C) {
partialResult := minFunc.GetPartialResult(minEvalCtx)
c.Assert(partialResult[0].GetInt64(), Equals, int64(1))
}

func (s *testAggFuncSuit) TestVarPop(c *C) {
col := &expression.Column{
Index: 0,
RetType: types.NewFieldType(mysql.TypeLonglong),
}
ctx := mock.NewContext()
varPopFunc := NewAggFuncDesc(s.ctx, ast.AggFuncVarPop, []expression.Expression{col}, false).GetAggFunc(ctx)
evalCtx := varPopFunc.CreateContext(s.ctx.GetSessionVars().StmtCtx)

result := varPopFunc.GetResult(evalCtx)
c.Assert(result.IsNull(), IsTrue)

for _, row := range s.rows {
err := varPopFunc.Update(evalCtx, s.ctx.GetSessionVars().StmtCtx, row)
c.Assert(err, IsNil)
}
result = varPopFunc.GetResult(evalCtx)
c.Assert(types.CompareFloat64(result.GetFloat64(), 561) == 0, IsTrue)
err := varPopFunc.Update(evalCtx, s.ctx.GetSessionVars().StmtCtx, s.nullRow)
c.Assert(err, IsNil)
result = varPopFunc.GetResult(evalCtx)
c.Assert(types.CompareFloat64(result.GetFloat64(), 561) == 0, IsTrue)
}

func (s *testAggFuncSuit) TestVarPopFinalMode(c *C) {
rows := make([][]types.Datum, 0, 100)
for i := 1; i <= 100; i++ {
rows = append(rows, types.MakeDatums(i, float64(i*i), float64(i*i*i)))
}
ctx := mock.NewContext()
cntCol := &expression.Column{
Index: 0,
RetType: types.NewFieldType(mysql.TypeLonglong),
}
sumCol := &expression.Column{
Index: 1,
RetType: types.NewFieldType(mysql.TypeDouble),
}
squareSumCol := &expression.Column{
Index: 2,
RetType: types.NewFieldType(mysql.TypeDouble),
}
varPopFunc := NewAggFuncDesc(s.ctx, ast.AggFuncVarPop, []expression.Expression{cntCol, sumCol, squareSumCol}, false)
varPopFunc.Mode = FinalMode
avgFunc := varPopFunc.GetAggFunc(ctx)
evalCtx := avgFunc.CreateContext(s.ctx.GetSessionVars().StmtCtx)

for _, row := range rows {
err := avgFunc.Update(evalCtx, s.ctx.GetSessionVars().StmtCtx, chunk.MutRowFromDatums(row).ToRow())
c.Assert(err, IsNil)
}
result := avgFunc.GetResult(evalCtx)
c.Assert(types.CompareFloat64(result.GetFloat64(), 561) == 0, IsTrue)
}
Loading

0 comments on commit 99c7f3a

Please sign in to comment.