diff --git a/sql/group.go b/sql/group.go index fc9004cb6328..d527a4b9cc04 100644 --- a/sql/group.go +++ b/sql/group.go @@ -487,22 +487,18 @@ func (v *extractAggregatesVisitor) VisitPre(expr parser.Expr) (recurse bool, new switch t := expr.(type) { case *parser.FuncExpr: - fn, err := t.Name.Normalize() - if err != nil { - v.err = err - return false, expr - } - - if impl, ok := parser.Aggregates[strings.ToLower(fn.Function())]; ok { + if agg := t.GetAggregateConstructor(); agg != nil { if len(t.Exprs) != 1 { // Type checking has already run on these expressions thus // if an aggregate function of the wrong arity gets here, // something has gone really wrong. - panic(fmt.Sprintf("%q has %d arguments (expected 1)", fn, len(t.Exprs))) + panic(fmt.Sprintf("%q has %d arguments (expected 1)", t.Name, len(t.Exprs))) } + argExpr := t.Exprs[0] + defer v.subAggregateVisitor.Reset() - parser.WalkExprConst(&v.subAggregateVisitor, t.Exprs[0]) + parser.WalkExprConst(&v.subAggregateVisitor, argExpr) if v.subAggregateVisitor.Aggregated { v.err = fmt.Errorf("aggregate function calls cannot be nested under %s", t.Name) return false, expr @@ -510,8 +506,8 @@ func (v *extractAggregatesVisitor) VisitPre(expr parser.Expr) (recurse bool, new f := &aggregateFuncHolder{ expr: t, - arg: t.Exprs[0].(parser.TypedExpr), - create: impl[0].AggregateFunc, + arg: argExpr.(parser.TypedExpr), + create: agg, group: v.n, buckets: make(map[string]parser.AggregateFunc), } @@ -603,7 +599,8 @@ func (a *aggregateFuncHolder) add(bucket []byte, d parser.Datum) error { a.buckets[string(bucket)] = impl } - return impl.Add(d) + impl.Add(d) + return nil } func (*aggregateFuncHolder) Variable() {} @@ -632,10 +629,7 @@ func (a *aggregateFuncHolder) Eval(ctx *parser.EvalContext) (parser.Datum, error found = a.create() } - datum, err := found.Result() - if err != nil { - return nil, err - } + datum := found.Result() // This is almost certainly the identity. Oh well. return datum.Eval(ctx) diff --git a/sql/parser/aggregate_builtins.go b/sql/parser/aggregate_builtins.go index a78ba42b1756..4c7531707c9d 100644 --- a/sql/parser/aggregate_builtins.go +++ b/sql/parser/aggregate_builtins.go @@ -15,13 +15,13 @@ package parser import ( + "fmt" "math" "strings" "gopkg.in/inf.v0" "github.com/cockroachdb/cockroach/util/decimal" - "github.com/pkg/errors" ) func init() { @@ -35,8 +35,8 @@ func init() { // AggregateFunc accumulates the result of a some function of a Datum. type AggregateFunc interface { - Add(Datum) error - Result() (Datum, error) + Add(Datum) + Result() Datum } var _ Visitor = &IsAggregateVisitor{} @@ -113,9 +113,9 @@ func (p *Parser) IsAggregate(n *SelectClause) bool { // execution. var Aggregates = map[string][]Builtin{ "avg": { - makeAggBuiltin(TypeInt, TypeDecimal, newAvgAggregate), - makeAggBuiltin(TypeFloat, TypeFloat, newAvgAggregate), - makeAggBuiltin(TypeDecimal, TypeDecimal, newAvgAggregate), + makeAggBuiltin(TypeInt, TypeDecimal, newIntAvgAggregate), + makeAggBuiltin(TypeFloat, TypeFloat, newFloatAvgAggregate), + makeAggBuiltin(TypeDecimal, TypeDecimal, newDecimalAvgAggregate), }, "bool_and": { @@ -132,21 +132,21 @@ var Aggregates = map[string][]Builtin{ "min": makeAggBuiltins(newMinAggregate, TypeBool, TypeInt, TypeFloat, TypeDecimal, TypeString, TypeBytes, TypeDate, TypeTimestamp, TypeInterval), "sum": { - makeAggBuiltin(TypeInt, TypeDecimal, newSumAggregate), - makeAggBuiltin(TypeFloat, TypeFloat, newSumAggregate), - makeAggBuiltin(TypeDecimal, TypeDecimal, newSumAggregate), + makeAggBuiltin(TypeInt, TypeDecimal, newIntSumAggregate), + makeAggBuiltin(TypeFloat, TypeFloat, newFloatSumAggregate), + makeAggBuiltin(TypeDecimal, TypeDecimal, newDecimalSumAggregate), }, "variance": { - makeAggBuiltin(TypeInt, TypeDecimal, newVarianceAggregate), - makeAggBuiltin(TypeDecimal, TypeDecimal, newVarianceAggregate), - makeAggBuiltin(TypeFloat, TypeFloat, newVarianceAggregate), + makeAggBuiltin(TypeInt, TypeDecimal, newIntVarianceAggregate), + makeAggBuiltin(TypeDecimal, TypeDecimal, newDecimalVarianceAggregate), + makeAggBuiltin(TypeFloat, TypeFloat, newFloatVarianceAggregate), }, "stddev": { - makeAggBuiltin(TypeInt, TypeDecimal, newStddevAggregate), - makeAggBuiltin(TypeDecimal, TypeDecimal, newStddevAggregate), - makeAggBuiltin(TypeFloat, TypeFloat, newStddevAggregate), + makeAggBuiltin(TypeInt, TypeDecimal, newIntStddevAggregate), + makeAggBuiltin(TypeDecimal, TypeDecimal, newDecimalStddevAggregate), + makeAggBuiltin(TypeFloat, TypeFloat, newFloatStddevAggregate), }, } @@ -181,9 +181,11 @@ var _ AggregateFunc = &avgAggregate{} var _ AggregateFunc = &countAggregate{} var _ AggregateFunc = &MaxAggregate{} var _ AggregateFunc = &MinAggregate{} -var _ AggregateFunc = &sumAggregate{} +var _ AggregateFunc = &intSumAggregate{} +var _ AggregateFunc = &decimalSumAggregate{} +var _ AggregateFunc = &floatSumAggregate{} var _ AggregateFunc = &stddevAggregate{} -var _ AggregateFunc = &varianceAggregate{} +var _ AggregateFunc = &intVarianceAggregate{} var _ AggregateFunc = &floatVarianceAggregate{} var _ AggregateFunc = &decimalVarianceAggregate{} var _ AggregateFunc = &identAggregate{} @@ -204,120 +206,106 @@ func NewIdentAggregate() AggregateFunc { } // Add sets the value to the passed datum. -func (a *identAggregate) Add(datum Datum) error { +func (a *identAggregate) Add(datum Datum) { a.val = datum - return nil } // Result returns the value most recently passed to Add. -func (a *identAggregate) Result() (Datum, error) { - return a.val, nil +func (a *identAggregate) Result() Datum { + return a.val } type avgAggregate struct { - sumAggregate + agg AggregateFunc count int } -func newAvgAggregate() AggregateFunc { - return &avgAggregate{} +func newIntAvgAggregate() AggregateFunc { + return &avgAggregate{agg: newIntSumAggregate()} +} +func newFloatAvgAggregate() AggregateFunc { + return &avgAggregate{agg: newFloatSumAggregate()} +} +func newDecimalAvgAggregate() AggregateFunc { + return &avgAggregate{agg: newDecimalSumAggregate()} } // Add accumulates the passed datum into the average. -func (a *avgAggregate) Add(datum Datum) error { +func (a *avgAggregate) Add(datum Datum) { if datum == DNull { - return nil - } - if err := a.sumAggregate.Add(datum); err != nil { - return err + return } + a.agg.Add(datum) a.count++ - return nil } // Result returns the average of all datums passed to Add. -func (a *avgAggregate) Result() (Datum, error) { - sum, err := a.sumAggregate.Result() - if err != nil { - return nil, err - } +func (a *avgAggregate) Result() Datum { + sum := a.agg.Result() if sum == DNull { - return sum, nil + return sum } switch t := sum.(type) { case *DFloat: - return NewDFloat(*t / DFloat(a.count)), nil + return NewDFloat(*t / DFloat(a.count)) case *DDecimal: count := inf.NewDec(int64(a.count), 0) t.QuoRound(&t.Dec, count, decimal.Precision, inf.RoundHalfUp) - return t, nil + return t default: - return nil, errors.Errorf("unexpected SUM result type: %s", t.Type()) + panic(fmt.Sprintf("unexpected SUM result type: %s", t.Type())) } } type boolAndAggregate struct { sawNonNull bool - sawFalse bool + result bool } func newBoolAndAggregate() AggregateFunc { return &boolAndAggregate{} } -func (a *boolAndAggregate) Add(datum Datum) error { +func (a *boolAndAggregate) Add(datum Datum) { if datum == DNull { - return nil + return } - a.sawNonNull = true - switch t := datum.(type) { - case *DBool: - if !a.sawFalse { - a.sawFalse = !bool(*t) - } - return nil - default: - return errors.Errorf("unexpected BOOL_AND argument type: %s", t.Type()) + if !a.sawNonNull { + a.sawNonNull = true + a.result = true } + a.result = a.result && bool(*datum.(*DBool)) } -func (a *boolAndAggregate) Result() (Datum, error) { +func (a *boolAndAggregate) Result() Datum { if !a.sawNonNull { - return DNull, nil + return DNull } - return MakeDBool(DBool(!a.sawFalse)), nil + return MakeDBool(DBool(a.result)) } type boolOrAggregate struct { sawNonNull bool - sawTrue bool + result bool } func newBoolOrAggregate() AggregateFunc { return &boolOrAggregate{} } -func (a *boolOrAggregate) Add(datum Datum) error { +func (a *boolOrAggregate) Add(datum Datum) { if datum == DNull { - return nil + return } a.sawNonNull = true - switch t := datum.(type) { - case *DBool: - if !a.sawTrue { - a.sawTrue = bool(*t) - } - return nil - default: - return errors.Errorf("unexpected BOOL_OR argument type: %s", t.Type()) - } + a.result = a.result || bool(*datum.(*DBool)) } -func (a *boolOrAggregate) Result() (Datum, error) { +func (a *boolOrAggregate) Result() Datum { if !a.sawNonNull { - return DNull, nil + return DNull } - return MakeDBool(DBool(a.sawTrue)), nil + return MakeDBool(DBool(a.result)) } type countAggregate struct { @@ -328,26 +316,16 @@ func newCountAggregate() AggregateFunc { return &countAggregate{} } -func (a *countAggregate) Add(datum Datum) error { +func (a *countAggregate) Add(datum Datum) { if datum == DNull { - return nil - } - switch t := datum.(type) { - case *DTuple: - for _, d := range *t { - if d != DNull { - a.count++ - break - } - } - default: - a.count++ + return } - return nil + a.count++ + return } -func (a *countAggregate) Result() (Datum, error) { - return NewDInt(DInt(a.count)), nil +func (a *countAggregate) Result() Datum { + return NewDInt(DInt(a.count)) } // MaxAggregate keeps track of the largest value passed to Add. @@ -360,27 +338,26 @@ func newMaxAggregate() AggregateFunc { } // Add sets the max to the larger of the current max or the passed datum. -func (a *MaxAggregate) Add(datum Datum) error { +func (a *MaxAggregate) Add(datum Datum) { if datum == DNull { - return nil + return } if a.max == nil { a.max = datum - return nil + return } c := a.max.Compare(datum) if c < 0 { a.max = datum } - return nil } // Result returns the largest value passed to Add. -func (a *MaxAggregate) Result() (Datum, error) { +func (a *MaxAggregate) Result() Datum { if a.max == nil { - return DNull, nil + return DNull } - return a.max, nil + return a.max } // MinAggregate keeps track of the smallest value passed to Add. @@ -393,140 +370,162 @@ func newMinAggregate() AggregateFunc { } // Add sets the min to the smaller of the current min or the passed datum. -func (a *MinAggregate) Add(datum Datum) error { +func (a *MinAggregate) Add(datum Datum) { if datum == DNull { - return nil + return } if a.min == nil { a.min = datum - return nil + return } c := a.min.Compare(datum) if c > 0 { a.min = datum } - return nil } // Result returns the smallest value passed to Add. -func (a *MinAggregate) Result() (Datum, error) { +func (a *MinAggregate) Result() Datum { if a.min == nil { - return DNull, nil + return DNull } - return a.min, nil + return a.min } -type sumAggregate struct { - sumType Datum - sumFloat DFloat - sumDec inf.Dec - tmpDec inf.Dec +type intSumAggregate struct { + // Either the `intSum` and `decSum` fields contains the + // result. Which one is used is determined by the `large` field + // below. + intSum int64 + decSum DDecimal + tmpDec inf.Dec + large bool + seenNonNull bool } -func newSumAggregate() AggregateFunc { - return &sumAggregate{} +func newIntSumAggregate() AggregateFunc { + return &intSumAggregate{} } // Add adds the value of the passed datum to the sum. -func (a *sumAggregate) Add(datum Datum) error { +func (a *intSumAggregate) Add(datum Datum) { if datum == DNull { - return nil + return + } + + t := int64(*datum.(*DInt)) + if t != 0 { + // The sum can be computed using a single int64 as long as the + // result of the addition does not overflow. However since Go + // does not provide checked addition, we have to check for the + // overflow explicitly. + if !a.large && + ((t < 0 && a.intSum < math.MinInt64-t) || + (t > 0 && a.intSum > math.MaxInt64-t)) { + // And overflow was detected; go to large integers, but keep the + // sum computed so far. + a.large = true + a.decSum.SetUnscaled(a.intSum) + } + + if a.large { + a.tmpDec.SetUnscaled(t) + a.decSum.Add(&a.decSum.Dec, &a.tmpDec) + } else { + a.intSum += t + } } - switch t := datum.(type) { - case *DFloat: - a.sumFloat += *t - case *DInt: - a.tmpDec.SetUnscaled(int64(*t)) - a.sumDec.Add(&a.sumDec, &a.tmpDec) - case *DDecimal: - a.sumDec.Add(&a.sumDec, &t.Dec) - default: - return errors.Errorf("unexpected SUM argument type: %s", datum.Type()) + a.seenNonNull = true +} + +// Result returns the sum. +func (a *intSumAggregate) Result() Datum { + if !a.seenNonNull { + return DNull } - if a.sumType == nil { - a.sumType = datum + if !a.large { + a.decSum.SetUnscaled(a.intSum) + } + return &a.decSum +} + +type decimalSumAggregate struct { + sum inf.Dec + sawNonNull bool +} + +func newDecimalSumAggregate() AggregateFunc { + return &decimalSumAggregate{} +} + +// Add adds the value of the passed datum to the sum. +func (a *decimalSumAggregate) Add(datum Datum) { + if datum == DNull { + return } - return nil + t := datum.(*DDecimal) + a.sum.Add(&a.sum, &t.Dec) + a.sawNonNull = true } // Result returns the sum. -func (a *sumAggregate) Result() (Datum, error) { - if a.sumType == nil { - return DNull, nil - } - switch { - case a.sumType.TypeEqual(TypeFloat): - return NewDFloat(a.sumFloat), nil - case a.sumType.TypeEqual(TypeInt), a.sumType.TypeEqual(TypeDecimal): - dd := &DDecimal{} - dd.Set(&a.sumDec) - return dd, nil - default: - panic("unreachable") +func (a *decimalSumAggregate) Result() Datum { + if !a.sawNonNull { + return DNull } + dd := &DDecimal{} + dd.Set(&a.sum) + return dd } -type varianceAggregate struct { - typedAggregate AggregateFunc - // Used for passing int64s as *inf.Dec values. - tmpDec DDecimal +type floatSumAggregate struct { + sum float64 + sawNonNull bool } -func newVarianceAggregate() AggregateFunc { - return &varianceAggregate{} +func newFloatSumAggregate() AggregateFunc { + return &floatSumAggregate{} } -func (a *varianceAggregate) Add(datum Datum) error { +// Add adds the value of the passed datum to the sum. +func (a *floatSumAggregate) Add(datum Datum) { if datum == DNull { - return nil + return } + t := datum.(*DFloat) + a.sum += float64(*t) + a.sawNonNull = true +} - const unexpectedErrFormat = "unexpected VARIANCE argument type: %s" - switch t := datum.(type) { - case *DFloat: - if a.typedAggregate == nil { - a.typedAggregate = newFloatVarianceAggregate() - } else { - switch a.typedAggregate.(type) { - case *floatVarianceAggregate: - default: - return errors.Errorf(unexpectedErrFormat, datum.Type()) - } - } - return a.typedAggregate.Add(t) - case *DInt: - if a.typedAggregate == nil { - a.typedAggregate = newDecimalVarianceAggregate() - } else { - switch a.typedAggregate.(type) { - case *decimalVarianceAggregate: - default: - return errors.Errorf(unexpectedErrFormat, datum.Type()) - } - } - a.tmpDec.SetUnscaled(int64(*t)) - return a.typedAggregate.Add(&a.tmpDec) - case *DDecimal: - if a.typedAggregate == nil { - a.typedAggregate = newDecimalVarianceAggregate() - } else { - switch a.typedAggregate.(type) { - case *decimalVarianceAggregate: - default: - return errors.Errorf(unexpectedErrFormat, datum.Type()) - } - } - return a.typedAggregate.Add(t) - default: - return errors.Errorf(unexpectedErrFormat, datum.Type()) +// Result returns the sum. +func (a *floatSumAggregate) Result() Datum { + if !a.sawNonNull { + return DNull } + return NewDFloat(DFloat(a.sum)) } -func (a *varianceAggregate) Result() (Datum, error) { - if a.typedAggregate == nil { - return DNull, nil +type intVarianceAggregate struct { + agg decimalVarianceAggregate + // Used for passing int64s as *inf.Dec values. + tmpDec DDecimal +} + +func newIntVarianceAggregate() AggregateFunc { + return &intVarianceAggregate{} +} + +func (a *intVarianceAggregate) Add(datum Datum) { + if datum == DNull { + return } - return a.typedAggregate.Result() + + a.tmpDec.SetUnscaled(int64(*datum.(*DInt))) + a.agg.Add(&a.tmpDec) +} + +func (a *intVarianceAggregate) Result() Datum { + return a.agg.Result() } type floatVarianceAggregate struct { @@ -539,7 +538,10 @@ func newFloatVarianceAggregate() AggregateFunc { return &floatVarianceAggregate{} } -func (a *floatVarianceAggregate) Add(datum Datum) error { +func (a *floatVarianceAggregate) Add(datum Datum) { + if datum == DNull { + return + } f := float64(*datum.(*DFloat)) // Uses the Knuth/Welford method for accurately computing variance online in a @@ -549,14 +551,13 @@ func (a *floatVarianceAggregate) Add(datum Datum) error { delta := f - a.mean a.mean += delta / float64(a.count) a.sqrDiff += delta * (f - a.mean) - return nil } -func (a *floatVarianceAggregate) Result() (Datum, error) { +func (a *floatVarianceAggregate) Result() Datum { if a.count < 2 { - return DNull, nil + return DNull } - return NewDFloat(DFloat(a.sqrDiff / (float64(a.count) - 1))), nil + return NewDFloat(DFloat(a.sqrDiff / (float64(a.count) - 1))) } type decimalVarianceAggregate struct { @@ -580,50 +581,64 @@ var ( decimalTwo = inf.NewDec(2, 0) ) -func (a *decimalVarianceAggregate) Add(datum Datum) error { - d := datum.(*DDecimal).Dec +func (a *decimalVarianceAggregate) Add(datum Datum) { + if datum == DNull { + return + } + d := &datum.(*DDecimal).Dec // Uses the Knuth/Welford method for accurately computing variance online in a // single pass. See http://www.johndcook.com/blog/standard_deviation/ and // https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Online_algorithm. a.count.Add(&a.count, decimalOne) - a.delta.Sub(&d, &a.mean) + a.delta.Sub(d, &a.mean) a.tmp.QuoRound(&a.delta, &a.count, decimal.Precision, inf.RoundHalfUp) a.mean.Add(&a.mean, &a.tmp) - a.tmp.Sub(&d, &a.mean) + a.tmp.Sub(d, &a.mean) a.sqrDiff.Add(&a.sqrDiff, a.delta.Mul(&a.delta, &a.tmp)) - return nil } -func (a *decimalVarianceAggregate) Result() (Datum, error) { +func (a *decimalVarianceAggregate) Result() Datum { if a.count.Cmp(decimalTwo) < 0 { - return DNull, nil + return DNull } a.tmp.Sub(&a.count, decimalOne) dd := &DDecimal{} dd.QuoRound(&a.sqrDiff, &a.tmp, decimal.Precision, inf.RoundHalfUp) - return dd, nil + return dd } type stddevAggregate struct { - varianceAggregate + agg AggregateFunc +} + +func newIntStddevAggregate() AggregateFunc { + return &stddevAggregate{agg: newIntVarianceAggregate()} +} +func newFloatStddevAggregate() AggregateFunc { + return &stddevAggregate{agg: newFloatVarianceAggregate()} +} +func newDecimalStddevAggregate() AggregateFunc { + return &stddevAggregate{agg: newDecimalVarianceAggregate()} } -func newStddevAggregate() AggregateFunc { - return &stddevAggregate{varianceAggregate: *newVarianceAggregate().(*varianceAggregate)} +// Add implements the AggregateFunc interface. +func (a *stddevAggregate) Add(datum Datum) { + a.agg.Add(datum) } -func (a *stddevAggregate) Result() (Datum, error) { - variance, err := a.varianceAggregate.Result() - if err != nil || variance == DNull { - return variance, err +// Result computes the square root of the variance. +func (a *stddevAggregate) Result() Datum { + variance := a.agg.Result() + if variance == DNull { + return variance } switch t := variance.(type) { case *DFloat: - return NewDFloat(DFloat(math.Sqrt(float64(*t)))), nil + return NewDFloat(DFloat(math.Sqrt(float64(*t)))) case *DDecimal: decimal.Sqrt(&t.Dec, &t.Dec, decimal.Precision) - return t, nil + return t } - return nil, errors.Errorf("unexpected variance result type: %s", variance.Type()) + panic(fmt.Sprintf("unexpected variance result type: %s", variance.Type())) } diff --git a/sql/parser/aggregate_builtins_test.go b/sql/parser/aggregate_builtins_test.go index 33f6ac6bcfc0..59e6b01877fd 100644 --- a/sql/parser/aggregate_builtins_test.go +++ b/sql/parser/aggregate_builtins_test.go @@ -77,30 +77,28 @@ func runBenchmarkAggregate(b *testing.B, aggFunc func() AggregateFunc, vals []Da for i := 0; i < b.N; i++ { aggImpl := aggFunc() for i := range vals { - if err := aggImpl.Add(vals[i]); err != nil { - b.Errorf("adding value to aggregate implementation %T failed: %v", aggImpl, err) - } + aggImpl.Add(vals[i]) } - if _, err := aggImpl.Result(); err != nil { - b.Errorf("taking result of aggregate implementation %T failed: %v", aggImpl, err) + if aggImpl.Result() == nil { + b.Errorf("taking result of aggregate implementation %T failed", aggImpl) } } } func BenchmarkAvgAggregateInt1K(b *testing.B) { - runBenchmarkAggregate(b, newAvgAggregate, makeIntTestDatum(1000)) + runBenchmarkAggregate(b, newIntAvgAggregate, makeIntTestDatum(1000)) } func BenchmarkAvgAggregateSmallInt1K(b *testing.B) { - runBenchmarkAggregate(b, newAvgAggregate, makeSmallIntTestDatum(1000)) + runBenchmarkAggregate(b, newIntAvgAggregate, makeSmallIntTestDatum(1000)) } func BenchmarkAvgAggregateFloat1K(b *testing.B) { - runBenchmarkAggregate(b, newAvgAggregate, makeFloatTestDatum(1000)) + runBenchmarkAggregate(b, newFloatAvgAggregate, makeFloatTestDatum(1000)) } func BenchmarkAvgAggregateDecimal1K(b *testing.B) { - runBenchmarkAggregate(b, newAvgAggregate, makeDecimalTestDatum(1000)) + runBenchmarkAggregate(b, newDecimalAvgAggregate, makeDecimalTestDatum(1000)) } func BenchmarkCountAggregate1K(b *testing.B) { @@ -108,19 +106,19 @@ func BenchmarkCountAggregate1K(b *testing.B) { } func BenchmarkSumAggregateInt1K(b *testing.B) { - runBenchmarkAggregate(b, newSumAggregate, makeIntTestDatum(1000)) + runBenchmarkAggregate(b, newIntSumAggregate, makeIntTestDatum(1000)) } func BenchmarkSumAggregateSmallInt1K(b *testing.B) { - runBenchmarkAggregate(b, newSumAggregate, makeSmallIntTestDatum(1000)) + runBenchmarkAggregate(b, newIntSumAggregate, makeSmallIntTestDatum(1000)) } func BenchmarkSumAggregateFloat1K(b *testing.B) { - runBenchmarkAggregate(b, newSumAggregate, makeFloatTestDatum(1000)) + runBenchmarkAggregate(b, newFloatSumAggregate, makeFloatTestDatum(1000)) } func BenchmarkSumAggregateDecimal1K(b *testing.B) { - runBenchmarkAggregate(b, newSumAggregate, makeDecimalTestDatum(1000)) + runBenchmarkAggregate(b, newDecimalSumAggregate, makeDecimalTestDatum(1000)) } func BenchmarkMaxAggregateInt1K(b *testing.B) { @@ -148,25 +146,25 @@ func BenchmarkMinAggregateDecimal1K(b *testing.B) { } func BenchmarkVarianceAggregateInt1K(b *testing.B) { - runBenchmarkAggregate(b, newVarianceAggregate, makeIntTestDatum(1000)) + runBenchmarkAggregate(b, newIntVarianceAggregate, makeIntTestDatum(1000)) } func BenchmarkVarianceAggregateFloat1K(b *testing.B) { - runBenchmarkAggregate(b, newVarianceAggregate, makeFloatTestDatum(1000)) + runBenchmarkAggregate(b, newFloatVarianceAggregate, makeFloatTestDatum(1000)) } func BenchmarkVarianceAggregateDecimal1K(b *testing.B) { - runBenchmarkAggregate(b, newVarianceAggregate, makeDecimalTestDatum(1000)) + runBenchmarkAggregate(b, newDecimalVarianceAggregate, makeDecimalTestDatum(1000)) } func BenchmarkStddevAggregateInt1K(b *testing.B) { - runBenchmarkAggregate(b, newStddevAggregate, makeIntTestDatum(1000)) + runBenchmarkAggregate(b, newIntStddevAggregate, makeIntTestDatum(1000)) } func BenchmarkStddevAggregateFloat1K(b *testing.B) { - runBenchmarkAggregate(b, newStddevAggregate, makeFloatTestDatum(1000)) + runBenchmarkAggregate(b, newFloatStddevAggregate, makeFloatTestDatum(1000)) } func BenchmarkStddevAggregateDecimal1K(b *testing.B) { - runBenchmarkAggregate(b, newStddevAggregate, makeDecimalTestDatum(1000)) + runBenchmarkAggregate(b, newDecimalStddevAggregate, makeDecimalTestDatum(1000)) } diff --git a/sql/parser/expr.go b/sql/parser/expr.go index 1e20187a24c0..ff90961b0a2b 100644 --- a/sql/parser/expr.go +++ b/sql/parser/expr.go @@ -794,6 +794,12 @@ type FuncExpr struct { fn Builtin } +// GetAggregateConstructor exposes the AggregateFunc field for use by +// the group node in package sql. +func (node *FuncExpr) GetAggregateConstructor() func() AggregateFunc { + return node.fn.AggregateFunc +} + type funcType int // FuncExpr.Type diff --git a/sql/testdata/aggregate b/sql/testdata/aggregate index e9f564565a4b..a8cbe209daae 100644 --- a/sql/testdata/aggregate +++ b/sql/testdata/aggregate @@ -385,6 +385,11 @@ SELECT COUNT(k)+COUNT(kv.v) FROM kv ---- 11 +query II +SELECT COUNT(NULL::int), COUNT((NULL, NULL)) +---- +0 1 + query IIII SELECT MIN(k), MAX(k), MIN(v), MAX(v) FROM kv ----