
[SPARK-9673][SQL] Sample standard deviation aggregation function #8058

Closed
wants to merge 18 commits into from
5 changes: 4 additions & 1 deletion python/pyspark/sql/functions.py
@@ -87,7 +87,10 @@ def _():
'sum': 'Aggregate function: returns the sum of all values in the expression.',
'avg': 'Aggregate function: returns the average of the values in a group.',
'mean': 'Aggregate function: returns the average of the values in a group.',
'sumDistinct': 'Aggregate function: returns the sum of distinct values in the expression.',
'stddev': 'Aggregate function: returns the sample standard deviation in a group.',
'stddevSamp': 'Aggregate function: returns the sample standard deviation in a group.',
'stddevPop': 'Aggregate function: returns the population standard deviation in a group.',
'sumDistinct': 'Aggregate function: returns the sum of distinct values in the expression.'
}

_functions_1_4 = {
42 changes: 42 additions & 0 deletions python/pyspark/sql/group.py
@@ -154,6 +154,48 @@ def min(self, *cols):
[Row(min(age)=2, min(height)=80)]
"""

@df_varargs_api
@since(1.5)
def stddev(self, *cols):
"""Computes the sample standard deviation for each numeric column for each group.
Alias for stddevSamp.

:param cols: list of column names (string). Non-numeric columns are ignored.

>>> df.groupBy().stddev('age').collect()
[Row(stddev_samp(age)=2.12...)]
>>> df3.groupBy().stddev('age', 'height').collect()
[Row(stddev_samp(age)=2.12..., stddev_samp(height)=3.53...)]
"""

@df_varargs_api
@since(1.5)
def stddevPop(self, *cols):
"""Computes the sample standard deviation for each numeric column for each group.
Alias for stddevSamp.

:param cols: list of column names (string). Non-numeric columns are ignored.

>>> df.groupBy().stddevPop('age').collect()
[Row(stddev_pop(age)=1.5)]
>>> df3.groupBy().stddevPop('age', 'height').collect()
[Row(stddev_pop(age)=1.5, stddev_pop(height)=2.5)]
"""

@df_varargs_api
@since(1.5)
def stddevSamp(self, *cols):
"""Computes the sample standard deviation for each numeric column for each group.

:param cols: list of column names (string). Non-numeric columns are ignored.

>>> df.groupBy().stddevSamp('age').collect()
[Row(stddev_samp(age)=2.12...)]
>>> df3.groupBy().stddevSamp('age', 'height').collect()
[Row(stddev_samp(age)=2.12..., stddev_samp(height)=3.53...)]
"""

@df_varargs_api
@since(1.3)
def sum(self, *cols):
@@ -20,7 +20,7 @@ package org.apache.spark.sql.catalyst.analysis
import scala.collection.mutable.ArrayBuffer

import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.expressions.aggregate.{Complete, AggregateExpression2, AggregateFunction2}
import org.apache.spark.sql.catalyst.expressions.aggregate.{StandardDeviation, Complete, AggregateExpression2, AggregateFunction2}
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.rules._
@@ -524,23 +524,60 @@ class Analyzer(
q transformExpressions {
case u @ UnresolvedFunction(name, children, isDistinct) =>
withPosition(u) {
registry.lookupFunction(name, children) match {
// We get an aggregate function built based on AggregateFunction2 interface.
// So, we wrap it in AggregateExpression2.
case agg2: AggregateFunction2 => AggregateExpression2(agg2, Complete, isDistinct)
// Currently, our old aggregate function interface supports SUM(DISTINCT ...)
// and COUNT(DISTINCT ...).
case sumDistinct: SumDistinct => sumDistinct
case countDistinct: CountDistinct => countDistinct
// DISTINCT is not meaningful with Max and Min.
case max: Max if isDistinct => max
case min: Min if isDistinct => min
// For other aggregate functions, DISTINCT keyword is not supported for now.
// Once we have converted to the new code path, we will allow the DISTINCT keyword.
case other: AggregateExpression1 if isDistinct =>
failAnalysis(s"$name does not support DISTINCT keyword.")
// If it does not have DISTINCT keyword, we will return it as is.
case other => other

// TODO: This is a hack. Hive uses stddev and std as aliases of stddev_pop, which
// is different from other widely used systems (these systems use these two function
// names as aliases of stddev_samp). So, we explicitly rename it to stddev_samp.
// Once we remove AggregateExpression1, we can remove this hack. Also, because
// we do not have stddev in SimpleFunctionRegistry (we want to resolve
// it to the HiveGenericUDAF based on AggregateExpression1 and then do the
// conversion to AggregateExpression2), if it does not exist in function registry,
// we create StandardDeviation directly.
name.toLowerCase match {
case "std" | "stddev" | "stddev_samp" =>
if (children.length != 1) {
failAnalysis(s"$name requires exactly one argument.")
}
val funcInRegistry =
registry
.lookupFunction("stddev_samp")
.map(_ => registry.lookupFunction("stddev_samp", children))
funcInRegistry.getOrElse {
AggregateExpression2(
StandardDeviation(children.head, sample = true), Complete, isDistinct)
}
case "stddev_pop" =>
if (children.length != 1) {
failAnalysis(s"$name requires exactly one argument.")
}
val funcInRegistry =
registry
.lookupFunction("stddev_pop")
.map(_ => registry.lookupFunction("stddev_pop", children))
funcInRegistry.getOrElse {
AggregateExpression2(
StandardDeviation(children.head, sample = false), Complete, isDistinct)
}
case _ =>
registry.lookupFunction(name, children) match {
// We get an aggregate function built based on AggregateFunction2 interface.
// So, we wrap it in AggregateExpression2.
case agg2: AggregateFunction2 =>
AggregateExpression2(agg2, Complete, isDistinct)
// Currently, our old aggregate function interface supports SUM(DISTINCT ...)
// and COUNT(DISTINCT ...).
case sumDistinct: SumDistinct => sumDistinct
case countDistinct: CountDistinct => countDistinct
// DISTINCT is not meaningful with Max and Min.
case max: Max if isDistinct => max
case min: Min if isDistinct => min
// For other aggregate functions, DISTINCT keyword is not supported for now.
// Once we have converted to the new code path, we will allow the DISTINCT keyword.
case other: AggregateExpression1 if isDistinct =>
failAnalysis(s"$name does not support DISTINCT keyword.")
// If it does not have DISTINCT keyword, we will return it as is.
case other => other
}
}
}
}
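
The user-visible effect of the renaming above: "std", "stddev", and "stddev_samp" all resolve to the sample variant, while "stddev_pop" resolves to the population variant. A quick illustrative sketch from the SQL side, assuming a Spark 1.5 sqlContext; the table name "nums" and column "v" are made up for this example:

val df = sqlContext.range(1, 4).selectExpr("id AS v") // v = 1, 2, 3
df.registerTempTable("nums")
// All three aliases resolve to the sample variant (divides by n - 1):
sqlContext.sql("SELECT std(v), stddev(v), stddev_samp(v) FROM nums").show()
// stddev_pop resolves to the population variant (divides by n):
sqlContext.sql("SELECT stddev_pop(v) FROM nums").show()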
@@ -302,3 +302,116 @@ case class Sum(child: Expression) extends AlgebraicAggregate {

override val evaluateExpression = Cast(currentSum, resultType)
}

/**
* Calculates the standard deviation using the online, parallelizable formula described at
* https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Parallel_algorithm
* If sample is true, we return the unbiased (sample) standard deviation.
*/
case class StandardDeviation(child: Expression, sample: Boolean) extends AlgebraicAggregate {

override def children: Seq[Expression] = child :: Nil

override def nullable: Boolean = true

// Return data type.
override def dataType: DataType = resultType

// Expected input data type.
override def inputTypes: Seq[AbstractDataType] = Seq(TypeCollection(NumericType, NullType))

private lazy val resultType = child.dataType match {
case DecimalType.Fixed(p, s) =>
DecimalType.bounded(p + 4, s + 4)
case _ => DoubleType
}

private lazy val sumDataType = child.dataType match {
case DecimalType.Fixed(p, s) => DecimalType.bounded(p + 10, s)
case _ => DoubleType
}

private lazy val currentCount = AttributeReference("currentCount", LongType)()
private lazy val leftCount = AttributeReference("leftCount", LongType)()
private lazy val rightCount = AttributeReference("rightCount", LongType)()
private lazy val currentDelta = AttributeReference("currentDelta", sumDataType)()
private lazy val currentAvg = AttributeReference("currentAverage", sumDataType)()
private lazy val currentMk = AttributeReference("currentMoment", sumDataType)()

// The buffer values must be updated in this exact order, because later update
// expressions read the slots written by earlier ones.
override lazy val bufferAttributes =
leftCount :: rightCount :: currentCount :: currentDelta :: currentAvg :: currentMk :: Nil

override lazy val initialValues = Seq(
/* leftCount = */ Literal(0L),
/* rightCount = */ Literal(0L),
/* currentCount = */ Literal(0L),
/* currentDelta = */ Cast(Literal(0), sumDataType),
/* currentAvg = */ Cast(Literal(0), sumDataType),
/* currentMk = */ Cast(Literal(0), sumDataType)
)

override lazy val updateExpressions = {
val currentValue = Coalesce(Cast(child, sumDataType) :: Cast(Literal(0), sumDataType) :: Nil)
val deltaX = Subtract(currentValue, currentAvg)
val updatedCount = If(IsNull(child), currentCount, currentCount + 1L)
// By the time this expression runs, currentCount and currentDelta already hold
// their updated values (see the ordering note on bufferAttributes above).
val updatedAvg = Add(currentAvg, Divide(currentDelta, currentCount))
Seq(
/* leftCount = */ leftCount, // used only during merging. dummy value
/* rightCount = */ rightCount, // used only during merging. dummy value
/* currentCount = */ updatedCount,
/* currentDelta = */ deltaX,
/* currentAvg = */ If(IsNull(child), currentAvg, updatedAvg), // null input must not shift the mean
/* currentMk = */ If(IsNull(child),
currentMk, Add(currentMk, currentDelta * Subtract(currentValue, currentAvg)))
)
}

override lazy val mergeExpressions = {
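// Like updateExpressions, these rely on the buffer write order: leftCount and
// rightCount snapshot the two input counts before the currentCount slot is
// overwritten with totalCount, which the later expressions then read.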
val totalCount = currentCount.left + currentCount.right
val deltaX = currentAvg.left - currentAvg.right
val deltaX2 = deltaX * deltaX
val sumMoments = currentMk.left + currentMk.right
val sumLeft = currentAvg.left * leftCount
val sumRight = currentAvg.right * rightCount
val mergedAvg = (sumLeft + sumRight) / currentCount
val mergedMk = sumMoments + currentDelta * leftCount / currentCount * rightCount
Seq(
/* leftCount = */ currentCount.left,
/* rightCount = */ currentCount.right,
/* currentCount = */ totalCount,
/* currentDelta = */ deltaX2,
/* currentAvg = */ If(EqualTo(leftCount, Cast(Literal(0L), LongType)), currentAvg.right,
If(EqualTo(rightCount, Cast(Literal(0L), LongType)), currentAvg.left, mergedAvg)),
/* currentMk = */ If(EqualTo(leftCount, Cast(Literal(0L), LongType)), currentMk.right,
If(EqualTo(rightCount, Cast(Literal(0L), LongType)), currentMk.left, mergedMk))
)
}

override lazy val evaluateExpression = {
val count =
if (sample) {
If(EqualTo(currentCount, Cast(Literal(0L), LongType)), currentCount,
currentCount - Cast(Literal(1L), LongType))
} else {
currentCount
}

child.dataType match {
case DecimalType.Fixed(p, s) =>
// increase the precision and scale to prevent precision loss
val dt = DecimalType.bounded(p + 14, s + 4)
Cast(Sqrt(Cast(currentMk, dt) / Cast(count, dt)), resultType)
case _ =>
Sqrt(Cast(currentMk, resultType) / Cast(count, resultType))
}
}

override def prettyName: String = {
if (sample) {
"stddev_samp"
} else {
"stddev_pop"
}
}
}
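
To make the arithmetic above easier to check, here is a minimal plain-Scala sketch of the same update, merge, and evaluate steps, assuming Double input; the Moments class and method names are illustrative, not part of the patch:

// Welford's online update plus the Chan et al. parallel merge, on Doubles.
case class Moments(count: Long, avg: Double, m2: Double)

// Online update for a single new observation x.
def update(s: Moments, x: Double): Moments = {
  val n = s.count + 1
  val delta = x - s.avg
  val avg = s.avg + delta / n
  // Uses the *updated* average, mirroring the buffer update order above.
  Moments(n, avg, s.m2 + delta * (x - avg))
}

// Parallel merge of two partial states.
def merge(l: Moments, r: Moments): Moments = {
  if (l.count == 0) r
  else if (r.count == 0) l
  else {
    val n = l.count + r.count
    val delta = l.avg - r.avg
    val avg = (l.avg * l.count + r.avg * r.count) / n
    val m2 = l.m2 + r.m2 + delta * delta * l.count * r.count / n
    Moments(n, avg, m2)
  }
}

// Final evaluation: divide by n for the population value, n - 1 for the sample value.
def stddev(s: Moments, sample: Boolean): Double =
  math.sqrt(s.m2 / (if (sample) s.count - 1 else s.count))

Folding ages 2 and 5 through update gives count = 2, avg = 3.5, m2 = 4.5, so the sample value is sqrt(4.5) ~ 2.12 and the population value is 1.5, matching the Python doctests above.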
@@ -96,6 +96,28 @@ object Utils {
aggregateFunction = aggregate.Sum(child),
mode = aggregate.Complete,
isDistinct = true)

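// Note: this case must come before the GenericUDAFStd case below, because
// "GenericUDAFStdSample".contains("GenericUDAFStd") is also true, so the order
// of these two string-matching guards matters.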
case hiveUDAF: AggregateExpression1
if hiveUDAF.getClass.getSimpleName == "HiveGenericUDAF" &&
hiveUDAF.toString.contains("GenericUDAFStdSample") =>
// We get a STDDEV_SAMP, which is originally resolved as a HiveGenericUDAF.
require(hiveUDAF.children.length == 1, "stddev_samp requires exactly one argument.")
val child = hiveUDAF.children.head
aggregate.AggregateExpression2(
aggregateFunction = aggregate.StandardDeviation(child, sample = true),
mode = aggregate.Complete,
isDistinct = false)

case hiveUDAF: AggregateExpression1
if hiveUDAF.getClass.getSimpleName == "HiveGenericUDAF" &&
hiveUDAF.toString.contains("GenericUDAFStd") =>
// We get a STDDEV_POP, which is originally resolved as a HiveGenericUDAF.
require(hiveUDAF.children.length == 1, "stddev_pop requires exactly one argument.")
val child = hiveUDAF.children.head
aggregate.AggregateExpression2(
aggregateFunction = aggregate.StandardDeviation(child, sample = false),
mode = aggregate.Complete,
isDistinct = false)
}
// Check if there is any expressions.AggregateExpression1 left.
// If so, we cannot convert this plan.
7 changes: 2 additions & 5 deletions sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
@@ -34,6 +34,7 @@ import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate.StandardDeviation
import org.apache.spark.sql.catalyst.plans.logical.{Filter, _}
import org.apache.spark.sql.catalyst.plans.{Inner, JoinType}
import org.apache.spark.sql.catalyst.{CatalystTypeConverters, ScalaReflection, SqlParser}
@@ -1268,15 +1269,11 @@ class DataFrame private[sql](
@scala.annotation.varargs
def describe(cols: String*): DataFrame = {

// TODO: Add stddev as an expression, and remove it from here.
def stddevExpr(expr: Expression): Expression =
Sqrt(Subtract(Average(Multiply(expr, expr)), Multiply(Average(expr), Average(expr))))

// The list of summary statistics to compute, in the form of expressions.
val statistics = List[(String, Expression => Expression)](
"count" -> Count,
"mean" -> Average,
"stddev" -> stddevExpr,
"stddev" -> ((e: Expression) => UnresolvedFunction("stddev_samp", e :: Nil, false)),
"min" -> Min,
"max" -> Max)

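
One behavioral note: the removed expression computed the population value via the numerically unstable sqrt(E[x^2] - E[x]^2) form, while the new stddev_samp call divides by n - 1. A hedged sketch of the visible difference, with made-up data and assuming a Spark 1.5 sqlContext:

val people = sqlContext.createDataFrame(Seq((2, "Alice"), (5, "Bob"))).toDF("age", "name")
people.describe("age").show()
// The "stddev" row is now the sample value (~2.12 for ages 2 and 5);
// the removed expression would have produced the population value, 1.5.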
48 changes: 47 additions & 1 deletion sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala
@@ -21,7 +21,7 @@ import scala.collection.JavaConversions._
import scala.language.implicitConversions

import org.apache.spark.annotation.Experimental
import org.apache.spark.sql.catalyst.analysis.{UnresolvedAlias, UnresolvedAttribute, Star}
import org.apache.spark.sql.catalyst.analysis.{UnresolvedFunction, UnresolvedAlias, UnresolvedAttribute, Star}
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.logical.{Rollup, Cube, Aggregate}
import org.apache.spark.sql.types.NumericType
@@ -283,6 +283,52 @@ class GroupedData protected[sql](
aggregateNumericColumns(colNames : _*)(Min)
}

/**
* Compute the sample standard deviation for each numeric column for each group.
* The resulting [[DataFrame]] will also contain the grouping columns.
* When specified columns are given, only compute the standard deviation for them.
*
* @since 1.5.0
*/
@scala.annotation.varargs
def stddev(colNames: String*): DataFrame = {
stddevSamp(colNames : _*)
}

/**
* Compute the population standard deviation for each numeric column for each group.
* The resulting [[DataFrame]] will also contain the grouping columns.
* When specified columns are given, only compute the standard deviation for them.
*
* @since 1.5.0
*/
@scala.annotation.varargs
def stddevPop(colNames: String*): DataFrame = {
def builder(e: Expression): Expression = {
Alias(
UnresolvedFunction("stddev_pop", e :: Nil, false),
s"stddev_pop(${e.prettyString})")()
}
aggregateNumericColumns(colNames : _*)(builder)
}

/**
* Compute the sample standard deviation for each numeric column for each group.
* The resulting [[DataFrame]] will also contain the grouping columns.
* When specified columns are given, only compute the standard deviation for them.
*
* @since 1.5.0
*/
@scala.annotation.varargs
def stddevSamp(colNames: String*): DataFrame = {
def builder(e: Expression): Expression = {
Alias(
UnresolvedFunction("stddev_samp", e :: Nil, false),
s"stddev_samp(${e.prettyString})")()
}
aggregateNumericColumns(colNames : _*)(builder)
}

/**
* Compute the sum for each numeric column for each group.
* The resulting [[DataFrame]] will also contain the grouping columns.
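
A short usage sketch for the three new methods, assuming a DataFrame df3 with numeric columns "age" and "height" as in the Python doctests:

df3.groupBy().stddev("age", "height").show()      // alias of stddevSamp
df3.groupBy().stddevSamp("age", "height").show()  // sample: divides by n - 1
df3.groupBy().stddevPop("age", "height").show()   // population: divides by n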
@@ -98,6 +98,10 @@ case class SortBasedAggregate(

override def simpleString: String = {
val allAggregateExpressions = nonCompleteAggregateExpressions ++ completeAggregateExpressions
s"""SortBasedAggregate ${groupingExpressions} ${allAggregateExpressions}"""

val keyString = groupingExpressions.mkString("[", ",", "]")
val valueString = allAggregateExpressions.mkString("[", ",", "]")
val outputString = output.mkString("[", ",", "]")
s"SortBasedAggregate(key=$keyString, functions=$valueString, output=$outputString)"
}
}