From 72a5c197c66699edcbfbecc44a4308635bc57deb Mon Sep 17 00:00:00 2001 From: Raphael 'kena' Poss Date: Thu, 18 Aug 2016 23:18:19 +0000 Subject: [PATCH] sql: simplify/optimize/fix the aggregate functions. MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Prior to this patch, `varianceAggregate`, `sumAggregate` (and indirectly `avg` and `stddev` aggregates which used the former) tried to be "smart" and share code between various argument types. However this actually caused extra complexity and run-time overhead. Indeed, every time the Add method was called: - a type switch was performed on the argument (although every value added to a given variance aggregator will always be of the same type) - a check was made whether an implementation for its argument's type had been instantiated already, to instantiate it if needed (although given the argument type is known after type checking, this instantiation could have been performed prior to the first Add call) This patch addresses this shortcoming by separating implementations by argument type. In addition, minor optimizations were also implemented: - the error return value was removed, since there is no error condition. - the `sum` (and thus `avg`) aggregates now use a single `int64` as long as the sum does not overflow. Also a long-standing bug with COUNT was fixed (the COUNT of a tuple of NULL values is 1, not 0). ``` name old time/op new time/op delta AvgAggregateInt1K-6 57.1µs ± 1% 55.3µs ± 0% -3.27% (p=0.000 n=9+8) AvgAggregateSmallInt1K-6 60.1µs ± 1% 17.8µs ± 2% -70.33% (p=0.000 n=9+10) AvgAggregateFloat1K-6 10.8µs ± 3% 8.8µs ± 1% -19.02% (p=0.000 n=10+9) AvgAggregateDecimal1K-6 252µs ± 3% 249µs ± 3% ~ (p=0.105 n=10+10) CountAggregate1K-6 5.47µs ± 2% 3.83µs ± 1% -29.99% (p=0.000 n=10+10) SumAggregateInt1K-6 52.1µs ± 1% 49.2µs ± 1% -5.59% (p=0.000 n=10+10) SumAggregateSmallInt1K-6 55.0µs ± 1% 13.3µs ± 2% -75.85% (p=0.000 n=10+10) SumAggregateFloat1K-6 7.04µs ± 5% 5.22µs ± 1% -25.79% (p=0.000 n=10+10) SumAggregateDecimal1K-6 246µs ± 2% 237µs ± 2% -3.77% (p=0.000 n=10+10) MaxAggregateInt1K-6 8.92µs ± 1% 8.48µs ± 2% -4.99% (p=0.000 n=8+10) MaxAggregateFloat1K-6 9.26µs ± 2% 9.94µs ± 1% +7.29% (p=0.000 n=10+10) MaxAggregateDecimal1K-6 113µs ±12% 126µs ±43% ~ (p=0.661 n=9+10) MinAggregateInt1K-6 9.52µs ± 1% 9.23µs ± 1% -3.02% (p=0.000 n=9+10) MinAggregateFloat1K-6 9.12µs ± 1% 8.88µs ± 1% -2.57% (p=0.000 n=8+10) MinAggregateDecimal1K-6 223µs ± 5% 224µs ± 2% ~ (p=0.842 n=10+9) VarianceAggregateInt1K-6 1.49ms ± 1% 1.49ms ± 1% ~ (p=0.274 n=10+8) VarianceAggregateFloat1K-6 14.2µs ± 1% 11.7µs ± 1% -17.42% (p=0.000 n=9+10) VarianceAggregateDecimal1K-6 1.36ms ± 2% 1.31ms ± 2% -4.12% (p=0.000 n=10+10) StddevAggregateInt1K-6 1.69ms ± 1% 1.65ms ± 2% -2.47% (p=0.000 n=10+10) StddevAggregateFloat1K-6 15.7µs ± 1% 11.8µs ± 2% -24.81% (p=0.000 n=9+10) StddevAggregateDecimal1K-6 1.38ms ± 1% 1.32ms ± 5% -4.41% (p=0.000 n=9+10) name old alloc/op new alloc/op delta AvgAggregateInt1K-6 772B ± 0% 692B ± 0% -10.36% (p=0.000 n=10+10) AvgAggregateSmallInt1K-6 756B ± 0% 628B ± 0% -16.93% (p=0.000 n=10+10) AvgAggregateFloat1K-6 128B ± 0% 64B ± 0% -50.00% (p=0.000 n=10+10) AvgAggregateDecimal1K-6 142kB ± 1% 142kB ± 1% ~ (p=0.239 n=10+10) CountAggregate1K-6 16.0B ± 0% 16.0B ± 0% ~ (all samples are equal) SumAggregateInt1K-6 304B ± 0% 192B ± 0% -36.84% (p=0.000 n=10+10) SumAggregateSmallInt1K-6 304B ± 0% 144B ± 0% -52.63% (p=0.000 n=10+10) SumAggregateFloat1K-6 120B ± 0% 24B ± 0% -80.00% (p=0.000 n=10+10) SumAggregateDecimal1K-6 142kB ± 1% 141kB ± 2% ~ (p=0.123 n=10+10) MaxAggregateInt1K-6 16.0B ± 0% 16.0B ± 0% ~ (all samples are equal) MaxAggregateFloat1K-6 16.0B ± 0% 16.0B ± 0% ~ (all samples are equal) MaxAggregateDecimal1K-6 57.7kB ±17% 67.7kB ±58% ~ (p=0.645 n=9+10) MinAggregateInt1K-6 16.0B ± 0% 16.0B ± 0% ~ (all samples are equal) MinAggregateFloat1K-6 16.0B ± 0% 16.0B ± 0% ~ (all samples are equal) MinAggregateDecimal1K-6 138kB ± 5% 140kB ± 2% ~ (p=0.305 n=10+9) VarianceAggregateInt1K-6 773kB ± 0% 773kB ± 0% ~ (p=0.743 n=8+9) VarianceAggregateFloat1K-6 104B ± 0% 40B ± 0% -61.54% (p=0.000 n=10+10) VarianceAggregateDecimal1K-6 709kB ± 2% 706kB ± 1% ~ (p=0.436 n=10+10) StddevAggregateInt1K-6 869kB ± 0% 868kB ± 0% ~ (p=0.128 n=9+10) StddevAggregateFloat1K-6 112B ± 0% 64B ± 0% -42.86% (p=0.000 n=10+10) StddevAggregateDecimal1K-6 716kB ± 1% 723kB ± 2% +1.06% (p=0.028 n=9+10) name old allocs/op new allocs/op delta AvgAggregateInt1K-6 16.0 ± 0% 15.0 ± 0% -6.25% (p=0.000 n=10+10) AvgAggregateSmallInt1K-6 16.0 ± 0% 14.0 ± 0% -12.50% (p=0.000 n=10+10) AvgAggregateFloat1K-6 3.00 ± 0% 4.00 ± 0% +33.33% (p=0.000 n=10+10) AvgAggregateDecimal1K-6 2.97k ± 1% 2.96k ± 1% ~ (p=0.239 n=10+10) CountAggregate1K-6 2.00 ± 0% 2.00 ± 0% ~ (all samples are equal) SumAggregateInt1K-6 5.00 ± 0% 3.00 ± 0% -40.00% (p=0.000 n=10+10) SumAggregateSmallInt1K-6 5.00 ± 0% 2.00 ± 0% -60.00% (p=0.000 n=10+10) SumAggregateFloat1K-6 2.00 ± 0% 2.00 ± 0% ~ (all samples are equal) SumAggregateDecimal1K-6 2.95k ± 1% 2.93k ± 2% ~ (p=0.128 n=10+10) MaxAggregateInt1K-6 1.00 ± 0% 1.00 ± 0% ~ (all samples are equal) MaxAggregateFloat1K-6 1.00 ± 0% 1.00 ± 0% ~ (all samples are equal) MaxAggregateDecimal1K-6 1.20k ±17% 1.41k ±58% ~ (p=0.645 n=9+10) MinAggregateInt1K-6 1.00 ± 0% 1.00 ± 0% ~ (all samples are equal) MinAggregateFloat1K-6 1.00 ± 0% 1.00 ± 0% ~ (all samples are equal) MinAggregateDecimal1K-6 2.88k ± 5% 2.92k ± 2% ~ (p=0.305 n=10+9) VarianceAggregateInt1K-6 17.4k ± 0% 17.4k ± 0% ~ (p=0.799 n=8+9) VarianceAggregateFloat1K-6 3.00 ± 0% 2.00 ± 0% -33.33% (p=0.000 n=10+10) VarianceAggregateDecimal1K-6 15.9k ± 1% 16.0k ± 1% ~ (p=0.905 n=9+10) StddevAggregateInt1K-6 19.1k ± 0% 19.1k ± 0% ~ (p=0.288 n=10+10) StddevAggregateFloat1K-6 4.00 ± 0% 4.00 ± 0% ~ (all samples are equal) StddevAggregateDecimal1K-6 16.1k ± 0% 16.3k ± 1% ~ (p=0.079 n=9+10) ``` --- sql/group.go | 26 +- sql/parser/aggregate_builtins.go | 429 +++++++++++++------------- sql/parser/aggregate_builtins_test.go | 36 +-- sql/parser/expr.go | 6 + sql/testdata/aggregate | 5 + 5 files changed, 260 insertions(+), 242 deletions(-) diff --git a/sql/group.go b/sql/group.go index fc9004cb6328..d527a4b9cc04 100644 --- a/sql/group.go +++ b/sql/group.go @@ -487,22 +487,18 @@ func (v *extractAggregatesVisitor) VisitPre(expr parser.Expr) (recurse bool, new switch t := expr.(type) { case *parser.FuncExpr: - fn, err := t.Name.Normalize() - if err != nil { - v.err = err - return false, expr - } - - if impl, ok := parser.Aggregates[strings.ToLower(fn.Function())]; ok { + if agg := t.GetAggregateConstructor(); agg != nil { if len(t.Exprs) != 1 { // Type checking has already run on these expressions thus // if an aggregate function of the wrong arity gets here, // something has gone really wrong. - panic(fmt.Sprintf("%q has %d arguments (expected 1)", fn, len(t.Exprs))) + panic(fmt.Sprintf("%q has %d arguments (expected 1)", t.Name, len(t.Exprs))) } + argExpr := t.Exprs[0] + defer v.subAggregateVisitor.Reset() - parser.WalkExprConst(&v.subAggregateVisitor, t.Exprs[0]) + parser.WalkExprConst(&v.subAggregateVisitor, argExpr) if v.subAggregateVisitor.Aggregated { v.err = fmt.Errorf("aggregate function calls cannot be nested under %s", t.Name) return false, expr @@ -510,8 +506,8 @@ func (v *extractAggregatesVisitor) VisitPre(expr parser.Expr) (recurse bool, new f := &aggregateFuncHolder{ expr: t, - arg: t.Exprs[0].(parser.TypedExpr), - create: impl[0].AggregateFunc, + arg: argExpr.(parser.TypedExpr), + create: agg, group: v.n, buckets: make(map[string]parser.AggregateFunc), } @@ -603,7 +599,8 @@ func (a *aggregateFuncHolder) add(bucket []byte, d parser.Datum) error { a.buckets[string(bucket)] = impl } - return impl.Add(d) + impl.Add(d) + return nil } func (*aggregateFuncHolder) Variable() {} @@ -632,10 +629,7 @@ func (a *aggregateFuncHolder) Eval(ctx *parser.EvalContext) (parser.Datum, error found = a.create() } - datum, err := found.Result() - if err != nil { - return nil, err - } + datum := found.Result() // This is almost certainly the identity. Oh well. return datum.Eval(ctx) diff --git a/sql/parser/aggregate_builtins.go b/sql/parser/aggregate_builtins.go index a78ba42b1756..4c7531707c9d 100644 --- a/sql/parser/aggregate_builtins.go +++ b/sql/parser/aggregate_builtins.go @@ -15,13 +15,13 @@ package parser import ( + "fmt" "math" "strings" "gopkg.in/inf.v0" "github.com/cockroachdb/cockroach/util/decimal" - "github.com/pkg/errors" ) func init() { @@ -35,8 +35,8 @@ func init() { // AggregateFunc accumulates the result of a some function of a Datum. type AggregateFunc interface { - Add(Datum) error - Result() (Datum, error) + Add(Datum) + Result() Datum } var _ Visitor = &IsAggregateVisitor{} @@ -113,9 +113,9 @@ func (p *Parser) IsAggregate(n *SelectClause) bool { // execution. var Aggregates = map[string][]Builtin{ "avg": { - makeAggBuiltin(TypeInt, TypeDecimal, newAvgAggregate), - makeAggBuiltin(TypeFloat, TypeFloat, newAvgAggregate), - makeAggBuiltin(TypeDecimal, TypeDecimal, newAvgAggregate), + makeAggBuiltin(TypeInt, TypeDecimal, newIntAvgAggregate), + makeAggBuiltin(TypeFloat, TypeFloat, newFloatAvgAggregate), + makeAggBuiltin(TypeDecimal, TypeDecimal, newDecimalAvgAggregate), }, "bool_and": { @@ -132,21 +132,21 @@ var Aggregates = map[string][]Builtin{ "min": makeAggBuiltins(newMinAggregate, TypeBool, TypeInt, TypeFloat, TypeDecimal, TypeString, TypeBytes, TypeDate, TypeTimestamp, TypeInterval), "sum": { - makeAggBuiltin(TypeInt, TypeDecimal, newSumAggregate), - makeAggBuiltin(TypeFloat, TypeFloat, newSumAggregate), - makeAggBuiltin(TypeDecimal, TypeDecimal, newSumAggregate), + makeAggBuiltin(TypeInt, TypeDecimal, newIntSumAggregate), + makeAggBuiltin(TypeFloat, TypeFloat, newFloatSumAggregate), + makeAggBuiltin(TypeDecimal, TypeDecimal, newDecimalSumAggregate), }, "variance": { - makeAggBuiltin(TypeInt, TypeDecimal, newVarianceAggregate), - makeAggBuiltin(TypeDecimal, TypeDecimal, newVarianceAggregate), - makeAggBuiltin(TypeFloat, TypeFloat, newVarianceAggregate), + makeAggBuiltin(TypeInt, TypeDecimal, newIntVarianceAggregate), + makeAggBuiltin(TypeDecimal, TypeDecimal, newDecimalVarianceAggregate), + makeAggBuiltin(TypeFloat, TypeFloat, newFloatVarianceAggregate), }, "stddev": { - makeAggBuiltin(TypeInt, TypeDecimal, newStddevAggregate), - makeAggBuiltin(TypeDecimal, TypeDecimal, newStddevAggregate), - makeAggBuiltin(TypeFloat, TypeFloat, newStddevAggregate), + makeAggBuiltin(TypeInt, TypeDecimal, newIntStddevAggregate), + makeAggBuiltin(TypeDecimal, TypeDecimal, newDecimalStddevAggregate), + makeAggBuiltin(TypeFloat, TypeFloat, newFloatStddevAggregate), }, } @@ -181,9 +181,11 @@ var _ AggregateFunc = &avgAggregate{} var _ AggregateFunc = &countAggregate{} var _ AggregateFunc = &MaxAggregate{} var _ AggregateFunc = &MinAggregate{} -var _ AggregateFunc = &sumAggregate{} +var _ AggregateFunc = &intSumAggregate{} +var _ AggregateFunc = &decimalSumAggregate{} +var _ AggregateFunc = &floatSumAggregate{} var _ AggregateFunc = &stddevAggregate{} -var _ AggregateFunc = &varianceAggregate{} +var _ AggregateFunc = &intVarianceAggregate{} var _ AggregateFunc = &floatVarianceAggregate{} var _ AggregateFunc = &decimalVarianceAggregate{} var _ AggregateFunc = &identAggregate{} @@ -204,120 +206,106 @@ func NewIdentAggregate() AggregateFunc { } // Add sets the value to the passed datum. -func (a *identAggregate) Add(datum Datum) error { +func (a *identAggregate) Add(datum Datum) { a.val = datum - return nil } // Result returns the value most recently passed to Add. -func (a *identAggregate) Result() (Datum, error) { - return a.val, nil +func (a *identAggregate) Result() Datum { + return a.val } type avgAggregate struct { - sumAggregate + agg AggregateFunc count int } -func newAvgAggregate() AggregateFunc { - return &avgAggregate{} +func newIntAvgAggregate() AggregateFunc { + return &avgAggregate{agg: newIntSumAggregate()} +} +func newFloatAvgAggregate() AggregateFunc { + return &avgAggregate{agg: newFloatSumAggregate()} +} +func newDecimalAvgAggregate() AggregateFunc { + return &avgAggregate{agg: newDecimalSumAggregate()} } // Add accumulates the passed datum into the average. -func (a *avgAggregate) Add(datum Datum) error { +func (a *avgAggregate) Add(datum Datum) { if datum == DNull { - return nil - } - if err := a.sumAggregate.Add(datum); err != nil { - return err + return } + a.agg.Add(datum) a.count++ - return nil } // Result returns the average of all datums passed to Add. -func (a *avgAggregate) Result() (Datum, error) { - sum, err := a.sumAggregate.Result() - if err != nil { - return nil, err - } +func (a *avgAggregate) Result() Datum { + sum := a.agg.Result() if sum == DNull { - return sum, nil + return sum } switch t := sum.(type) { case *DFloat: - return NewDFloat(*t / DFloat(a.count)), nil + return NewDFloat(*t / DFloat(a.count)) case *DDecimal: count := inf.NewDec(int64(a.count), 0) t.QuoRound(&t.Dec, count, decimal.Precision, inf.RoundHalfUp) - return t, nil + return t default: - return nil, errors.Errorf("unexpected SUM result type: %s", t.Type()) + panic(fmt.Sprintf("unexpected SUM result type: %s", t.Type())) } } type boolAndAggregate struct { sawNonNull bool - sawFalse bool + result bool } func newBoolAndAggregate() AggregateFunc { return &boolAndAggregate{} } -func (a *boolAndAggregate) Add(datum Datum) error { +func (a *boolAndAggregate) Add(datum Datum) { if datum == DNull { - return nil + return } - a.sawNonNull = true - switch t := datum.(type) { - case *DBool: - if !a.sawFalse { - a.sawFalse = !bool(*t) - } - return nil - default: - return errors.Errorf("unexpected BOOL_AND argument type: %s", t.Type()) + if !a.sawNonNull { + a.sawNonNull = true + a.result = true } + a.result = a.result && bool(*datum.(*DBool)) } -func (a *boolAndAggregate) Result() (Datum, error) { +func (a *boolAndAggregate) Result() Datum { if !a.sawNonNull { - return DNull, nil + return DNull } - return MakeDBool(DBool(!a.sawFalse)), nil + return MakeDBool(DBool(a.result)) } type boolOrAggregate struct { sawNonNull bool - sawTrue bool + result bool } func newBoolOrAggregate() AggregateFunc { return &boolOrAggregate{} } -func (a *boolOrAggregate) Add(datum Datum) error { +func (a *boolOrAggregate) Add(datum Datum) { if datum == DNull { - return nil + return } a.sawNonNull = true - switch t := datum.(type) { - case *DBool: - if !a.sawTrue { - a.sawTrue = bool(*t) - } - return nil - default: - return errors.Errorf("unexpected BOOL_OR argument type: %s", t.Type()) - } + a.result = a.result || bool(*datum.(*DBool)) } -func (a *boolOrAggregate) Result() (Datum, error) { +func (a *boolOrAggregate) Result() Datum { if !a.sawNonNull { - return DNull, nil + return DNull } - return MakeDBool(DBool(a.sawTrue)), nil + return MakeDBool(DBool(a.result)) } type countAggregate struct { @@ -328,26 +316,16 @@ func newCountAggregate() AggregateFunc { return &countAggregate{} } -func (a *countAggregate) Add(datum Datum) error { +func (a *countAggregate) Add(datum Datum) { if datum == DNull { - return nil - } - switch t := datum.(type) { - case *DTuple: - for _, d := range *t { - if d != DNull { - a.count++ - break - } - } - default: - a.count++ + return } - return nil + a.count++ + return } -func (a *countAggregate) Result() (Datum, error) { - return NewDInt(DInt(a.count)), nil +func (a *countAggregate) Result() Datum { + return NewDInt(DInt(a.count)) } // MaxAggregate keeps track of the largest value passed to Add. @@ -360,27 +338,26 @@ func newMaxAggregate() AggregateFunc { } // Add sets the max to the larger of the current max or the passed datum. -func (a *MaxAggregate) Add(datum Datum) error { +func (a *MaxAggregate) Add(datum Datum) { if datum == DNull { - return nil + return } if a.max == nil { a.max = datum - return nil + return } c := a.max.Compare(datum) if c < 0 { a.max = datum } - return nil } // Result returns the largest value passed to Add. -func (a *MaxAggregate) Result() (Datum, error) { +func (a *MaxAggregate) Result() Datum { if a.max == nil { - return DNull, nil + return DNull } - return a.max, nil + return a.max } // MinAggregate keeps track of the smallest value passed to Add. @@ -393,140 +370,162 @@ func newMinAggregate() AggregateFunc { } // Add sets the min to the smaller of the current min or the passed datum. -func (a *MinAggregate) Add(datum Datum) error { +func (a *MinAggregate) Add(datum Datum) { if datum == DNull { - return nil + return } if a.min == nil { a.min = datum - return nil + return } c := a.min.Compare(datum) if c > 0 { a.min = datum } - return nil } // Result returns the smallest value passed to Add. -func (a *MinAggregate) Result() (Datum, error) { +func (a *MinAggregate) Result() Datum { if a.min == nil { - return DNull, nil + return DNull } - return a.min, nil + return a.min } -type sumAggregate struct { - sumType Datum - sumFloat DFloat - sumDec inf.Dec - tmpDec inf.Dec +type intSumAggregate struct { + // Either the `intSum` and `decSum` fields contains the + // result. Which one is used is determined by the `large` field + // below. + intSum int64 + decSum DDecimal + tmpDec inf.Dec + large bool + seenNonNull bool } -func newSumAggregate() AggregateFunc { - return &sumAggregate{} +func newIntSumAggregate() AggregateFunc { + return &intSumAggregate{} } // Add adds the value of the passed datum to the sum. -func (a *sumAggregate) Add(datum Datum) error { +func (a *intSumAggregate) Add(datum Datum) { if datum == DNull { - return nil + return + } + + t := int64(*datum.(*DInt)) + if t != 0 { + // The sum can be computed using a single int64 as long as the + // result of the addition does not overflow. However since Go + // does not provide checked addition, we have to check for the + // overflow explicitly. + if !a.large && + ((t < 0 && a.intSum < math.MinInt64-t) || + (t > 0 && a.intSum > math.MaxInt64-t)) { + // And overflow was detected; go to large integers, but keep the + // sum computed so far. + a.large = true + a.decSum.SetUnscaled(a.intSum) + } + + if a.large { + a.tmpDec.SetUnscaled(t) + a.decSum.Add(&a.decSum.Dec, &a.tmpDec) + } else { + a.intSum += t + } } - switch t := datum.(type) { - case *DFloat: - a.sumFloat += *t - case *DInt: - a.tmpDec.SetUnscaled(int64(*t)) - a.sumDec.Add(&a.sumDec, &a.tmpDec) - case *DDecimal: - a.sumDec.Add(&a.sumDec, &t.Dec) - default: - return errors.Errorf("unexpected SUM argument type: %s", datum.Type()) + a.seenNonNull = true +} + +// Result returns the sum. +func (a *intSumAggregate) Result() Datum { + if !a.seenNonNull { + return DNull } - if a.sumType == nil { - a.sumType = datum + if !a.large { + a.decSum.SetUnscaled(a.intSum) + } + return &a.decSum +} + +type decimalSumAggregate struct { + sum inf.Dec + sawNonNull bool +} + +func newDecimalSumAggregate() AggregateFunc { + return &decimalSumAggregate{} +} + +// Add adds the value of the passed datum to the sum. +func (a *decimalSumAggregate) Add(datum Datum) { + if datum == DNull { + return } - return nil + t := datum.(*DDecimal) + a.sum.Add(&a.sum, &t.Dec) + a.sawNonNull = true } // Result returns the sum. -func (a *sumAggregate) Result() (Datum, error) { - if a.sumType == nil { - return DNull, nil - } - switch { - case a.sumType.TypeEqual(TypeFloat): - return NewDFloat(a.sumFloat), nil - case a.sumType.TypeEqual(TypeInt), a.sumType.TypeEqual(TypeDecimal): - dd := &DDecimal{} - dd.Set(&a.sumDec) - return dd, nil - default: - panic("unreachable") +func (a *decimalSumAggregate) Result() Datum { + if !a.sawNonNull { + return DNull } + dd := &DDecimal{} + dd.Set(&a.sum) + return dd } -type varianceAggregate struct { - typedAggregate AggregateFunc - // Used for passing int64s as *inf.Dec values. - tmpDec DDecimal +type floatSumAggregate struct { + sum float64 + sawNonNull bool } -func newVarianceAggregate() AggregateFunc { - return &varianceAggregate{} +func newFloatSumAggregate() AggregateFunc { + return &floatSumAggregate{} } -func (a *varianceAggregate) Add(datum Datum) error { +// Add adds the value of the passed datum to the sum. +func (a *floatSumAggregate) Add(datum Datum) { if datum == DNull { - return nil + return } + t := datum.(*DFloat) + a.sum += float64(*t) + a.sawNonNull = true +} - const unexpectedErrFormat = "unexpected VARIANCE argument type: %s" - switch t := datum.(type) { - case *DFloat: - if a.typedAggregate == nil { - a.typedAggregate = newFloatVarianceAggregate() - } else { - switch a.typedAggregate.(type) { - case *floatVarianceAggregate: - default: - return errors.Errorf(unexpectedErrFormat, datum.Type()) - } - } - return a.typedAggregate.Add(t) - case *DInt: - if a.typedAggregate == nil { - a.typedAggregate = newDecimalVarianceAggregate() - } else { - switch a.typedAggregate.(type) { - case *decimalVarianceAggregate: - default: - return errors.Errorf(unexpectedErrFormat, datum.Type()) - } - } - a.tmpDec.SetUnscaled(int64(*t)) - return a.typedAggregate.Add(&a.tmpDec) - case *DDecimal: - if a.typedAggregate == nil { - a.typedAggregate = newDecimalVarianceAggregate() - } else { - switch a.typedAggregate.(type) { - case *decimalVarianceAggregate: - default: - return errors.Errorf(unexpectedErrFormat, datum.Type()) - } - } - return a.typedAggregate.Add(t) - default: - return errors.Errorf(unexpectedErrFormat, datum.Type()) +// Result returns the sum. +func (a *floatSumAggregate) Result() Datum { + if !a.sawNonNull { + return DNull } + return NewDFloat(DFloat(a.sum)) } -func (a *varianceAggregate) Result() (Datum, error) { - if a.typedAggregate == nil { - return DNull, nil +type intVarianceAggregate struct { + agg decimalVarianceAggregate + // Used for passing int64s as *inf.Dec values. + tmpDec DDecimal +} + +func newIntVarianceAggregate() AggregateFunc { + return &intVarianceAggregate{} +} + +func (a *intVarianceAggregate) Add(datum Datum) { + if datum == DNull { + return } - return a.typedAggregate.Result() + + a.tmpDec.SetUnscaled(int64(*datum.(*DInt))) + a.agg.Add(&a.tmpDec) +} + +func (a *intVarianceAggregate) Result() Datum { + return a.agg.Result() } type floatVarianceAggregate struct { @@ -539,7 +538,10 @@ func newFloatVarianceAggregate() AggregateFunc { return &floatVarianceAggregate{} } -func (a *floatVarianceAggregate) Add(datum Datum) error { +func (a *floatVarianceAggregate) Add(datum Datum) { + if datum == DNull { + return + } f := float64(*datum.(*DFloat)) // Uses the Knuth/Welford method for accurately computing variance online in a @@ -549,14 +551,13 @@ func (a *floatVarianceAggregate) Add(datum Datum) error { delta := f - a.mean a.mean += delta / float64(a.count) a.sqrDiff += delta * (f - a.mean) - return nil } -func (a *floatVarianceAggregate) Result() (Datum, error) { +func (a *floatVarianceAggregate) Result() Datum { if a.count < 2 { - return DNull, nil + return DNull } - return NewDFloat(DFloat(a.sqrDiff / (float64(a.count) - 1))), nil + return NewDFloat(DFloat(a.sqrDiff / (float64(a.count) - 1))) } type decimalVarianceAggregate struct { @@ -580,50 +581,64 @@ var ( decimalTwo = inf.NewDec(2, 0) ) -func (a *decimalVarianceAggregate) Add(datum Datum) error { - d := datum.(*DDecimal).Dec +func (a *decimalVarianceAggregate) Add(datum Datum) { + if datum == DNull { + return + } + d := &datum.(*DDecimal).Dec // Uses the Knuth/Welford method for accurately computing variance online in a // single pass. See http://www.johndcook.com/blog/standard_deviation/ and // https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Online_algorithm. a.count.Add(&a.count, decimalOne) - a.delta.Sub(&d, &a.mean) + a.delta.Sub(d, &a.mean) a.tmp.QuoRound(&a.delta, &a.count, decimal.Precision, inf.RoundHalfUp) a.mean.Add(&a.mean, &a.tmp) - a.tmp.Sub(&d, &a.mean) + a.tmp.Sub(d, &a.mean) a.sqrDiff.Add(&a.sqrDiff, a.delta.Mul(&a.delta, &a.tmp)) - return nil } -func (a *decimalVarianceAggregate) Result() (Datum, error) { +func (a *decimalVarianceAggregate) Result() Datum { if a.count.Cmp(decimalTwo) < 0 { - return DNull, nil + return DNull } a.tmp.Sub(&a.count, decimalOne) dd := &DDecimal{} dd.QuoRound(&a.sqrDiff, &a.tmp, decimal.Precision, inf.RoundHalfUp) - return dd, nil + return dd } type stddevAggregate struct { - varianceAggregate + agg AggregateFunc +} + +func newIntStddevAggregate() AggregateFunc { + return &stddevAggregate{agg: newIntVarianceAggregate()} +} +func newFloatStddevAggregate() AggregateFunc { + return &stddevAggregate{agg: newFloatVarianceAggregate()} +} +func newDecimalStddevAggregate() AggregateFunc { + return &stddevAggregate{agg: newDecimalVarianceAggregate()} } -func newStddevAggregate() AggregateFunc { - return &stddevAggregate{varianceAggregate: *newVarianceAggregate().(*varianceAggregate)} +// Add implements the AggregateFunc interface. +func (a *stddevAggregate) Add(datum Datum) { + a.agg.Add(datum) } -func (a *stddevAggregate) Result() (Datum, error) { - variance, err := a.varianceAggregate.Result() - if err != nil || variance == DNull { - return variance, err +// Result computes the square root of the variance. +func (a *stddevAggregate) Result() Datum { + variance := a.agg.Result() + if variance == DNull { + return variance } switch t := variance.(type) { case *DFloat: - return NewDFloat(DFloat(math.Sqrt(float64(*t)))), nil + return NewDFloat(DFloat(math.Sqrt(float64(*t)))) case *DDecimal: decimal.Sqrt(&t.Dec, &t.Dec, decimal.Precision) - return t, nil + return t } - return nil, errors.Errorf("unexpected variance result type: %s", variance.Type()) + panic(fmt.Sprintf("unexpected variance result type: %s", variance.Type())) } diff --git a/sql/parser/aggregate_builtins_test.go b/sql/parser/aggregate_builtins_test.go index 33f6ac6bcfc0..59e6b01877fd 100644 --- a/sql/parser/aggregate_builtins_test.go +++ b/sql/parser/aggregate_builtins_test.go @@ -77,30 +77,28 @@ func runBenchmarkAggregate(b *testing.B, aggFunc func() AggregateFunc, vals []Da for i := 0; i < b.N; i++ { aggImpl := aggFunc() for i := range vals { - if err := aggImpl.Add(vals[i]); err != nil { - b.Errorf("adding value to aggregate implementation %T failed: %v", aggImpl, err) - } + aggImpl.Add(vals[i]) } - if _, err := aggImpl.Result(); err != nil { - b.Errorf("taking result of aggregate implementation %T failed: %v", aggImpl, err) + if aggImpl.Result() == nil { + b.Errorf("taking result of aggregate implementation %T failed", aggImpl) } } } func BenchmarkAvgAggregateInt1K(b *testing.B) { - runBenchmarkAggregate(b, newAvgAggregate, makeIntTestDatum(1000)) + runBenchmarkAggregate(b, newIntAvgAggregate, makeIntTestDatum(1000)) } func BenchmarkAvgAggregateSmallInt1K(b *testing.B) { - runBenchmarkAggregate(b, newAvgAggregate, makeSmallIntTestDatum(1000)) + runBenchmarkAggregate(b, newIntAvgAggregate, makeSmallIntTestDatum(1000)) } func BenchmarkAvgAggregateFloat1K(b *testing.B) { - runBenchmarkAggregate(b, newAvgAggregate, makeFloatTestDatum(1000)) + runBenchmarkAggregate(b, newFloatAvgAggregate, makeFloatTestDatum(1000)) } func BenchmarkAvgAggregateDecimal1K(b *testing.B) { - runBenchmarkAggregate(b, newAvgAggregate, makeDecimalTestDatum(1000)) + runBenchmarkAggregate(b, newDecimalAvgAggregate, makeDecimalTestDatum(1000)) } func BenchmarkCountAggregate1K(b *testing.B) { @@ -108,19 +106,19 @@ func BenchmarkCountAggregate1K(b *testing.B) { } func BenchmarkSumAggregateInt1K(b *testing.B) { - runBenchmarkAggregate(b, newSumAggregate, makeIntTestDatum(1000)) + runBenchmarkAggregate(b, newIntSumAggregate, makeIntTestDatum(1000)) } func BenchmarkSumAggregateSmallInt1K(b *testing.B) { - runBenchmarkAggregate(b, newSumAggregate, makeSmallIntTestDatum(1000)) + runBenchmarkAggregate(b, newIntSumAggregate, makeSmallIntTestDatum(1000)) } func BenchmarkSumAggregateFloat1K(b *testing.B) { - runBenchmarkAggregate(b, newSumAggregate, makeFloatTestDatum(1000)) + runBenchmarkAggregate(b, newFloatSumAggregate, makeFloatTestDatum(1000)) } func BenchmarkSumAggregateDecimal1K(b *testing.B) { - runBenchmarkAggregate(b, newSumAggregate, makeDecimalTestDatum(1000)) + runBenchmarkAggregate(b, newDecimalSumAggregate, makeDecimalTestDatum(1000)) } func BenchmarkMaxAggregateInt1K(b *testing.B) { @@ -148,25 +146,25 @@ func BenchmarkMinAggregateDecimal1K(b *testing.B) { } func BenchmarkVarianceAggregateInt1K(b *testing.B) { - runBenchmarkAggregate(b, newVarianceAggregate, makeIntTestDatum(1000)) + runBenchmarkAggregate(b, newIntVarianceAggregate, makeIntTestDatum(1000)) } func BenchmarkVarianceAggregateFloat1K(b *testing.B) { - runBenchmarkAggregate(b, newVarianceAggregate, makeFloatTestDatum(1000)) + runBenchmarkAggregate(b, newFloatVarianceAggregate, makeFloatTestDatum(1000)) } func BenchmarkVarianceAggregateDecimal1K(b *testing.B) { - runBenchmarkAggregate(b, newVarianceAggregate, makeDecimalTestDatum(1000)) + runBenchmarkAggregate(b, newDecimalVarianceAggregate, makeDecimalTestDatum(1000)) } func BenchmarkStddevAggregateInt1K(b *testing.B) { - runBenchmarkAggregate(b, newStddevAggregate, makeIntTestDatum(1000)) + runBenchmarkAggregate(b, newIntStddevAggregate, makeIntTestDatum(1000)) } func BenchmarkStddevAggregateFloat1K(b *testing.B) { - runBenchmarkAggregate(b, newStddevAggregate, makeFloatTestDatum(1000)) + runBenchmarkAggregate(b, newFloatStddevAggregate, makeFloatTestDatum(1000)) } func BenchmarkStddevAggregateDecimal1K(b *testing.B) { - runBenchmarkAggregate(b, newStddevAggregate, makeDecimalTestDatum(1000)) + runBenchmarkAggregate(b, newDecimalStddevAggregate, makeDecimalTestDatum(1000)) } diff --git a/sql/parser/expr.go b/sql/parser/expr.go index 1e20187a24c0..ff90961b0a2b 100644 --- a/sql/parser/expr.go +++ b/sql/parser/expr.go @@ -794,6 +794,12 @@ type FuncExpr struct { fn Builtin } +// GetAggregateConstructor exposes the AggregateFunc field for use by +// the group node in package sql. +func (node *FuncExpr) GetAggregateConstructor() func() AggregateFunc { + return node.fn.AggregateFunc +} + type funcType int // FuncExpr.Type diff --git a/sql/testdata/aggregate b/sql/testdata/aggregate index e9f564565a4b..a8cbe209daae 100644 --- a/sql/testdata/aggregate +++ b/sql/testdata/aggregate @@ -385,6 +385,11 @@ SELECT COUNT(k)+COUNT(kv.v) FROM kv ---- 11 +query II +SELECT COUNT(NULL::int), COUNT((NULL, NULL)) +---- +0 1 + query IIII SELECT MIN(k), MAX(k), MIN(v), MAX(v) FROM kv ----