
SPARK-1627: Support external aggregation by using Aggregator in Spark SQL #867

Closed · wants to merge 5 commits
11 changes: 8 additions & 3 deletions core/src/main/scala/org/apache/spark/Aggregator.scala
@@ -19,6 +19,7 @@ package org.apache.spark

import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.util.collection.{AppendOnlyMap, ExternalAppendOnlyMap}
import org.apache.spark.serializer.Serializer

/**
* :: DeveloperApi ::
@@ -27,12 +28,14 @@ import org.apache.spark.util.collection.{AppendOnlyMap, ExternalAppendOnlyMap}
* @param createCombiner function to create the initial value of the aggregation.
* @param mergeValue function to merge a new value into the aggregation result.
* @param mergeCombiners function to merge outputs from multiple mergeValue function.
* @param serializer serializer to persist data internally.
*/
@DeveloperApi
case class Aggregator[K, V, C] (
createCombiner: V => C,
mergeValue: (C, V) => C,
mergeCombiners: (C, C) => C) {
mergeCombiners: (C, C) => C,
serializer: Serializer = SparkEnv.get.serializer) {
Contributor: also update the documentation above to add the new parameter.

Author: Done

private val externalSorting = SparkEnv.get.conf.getBoolean("spark.shuffle.spill", true)

@@ -54,7 +57,8 @@ case class Aggregator[K, V, C] (
}
combiners.iterator
} else {
val combiners = new ExternalAppendOnlyMap[K, V, C](createCombiner, mergeValue, mergeCombiners)
val combiners = new ExternalAppendOnlyMap[K, V, C](
createCombiner, mergeValue, mergeCombiners, serializer)
while (iter.hasNext) {
val (k, v) = iter.next()
combiners.insert(k, v)
@@ -83,7 +87,8 @@ case class Aggregator[K, V, C] (
}
combiners.iterator
} else {
val combiners = new ExternalAppendOnlyMap[K, C, C](identity, mergeCombiners, mergeCombiners)
val combiners = new ExternalAppendOnlyMap[K, C, C](
identity, mergeCombiners, mergeCombiners, serializer)
while (iter.hasNext) {
val (k, c) = iter.next()
combiners.insert(k, c)
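
Threading the serializer through to ExternalAppendOnlyMap lets callers that spill values the default serializer cannot handle (such as Spark SQL rows) supply their own. Below is a minimal sketch of how the new parameter might be used; it assumes an active local SparkContext (so SparkEnv is initialized), and the key/value types and the choice of KryoSerializer are illustrative, not part of the patch.

import org.apache.spark.{Aggregator, SparkConf, SparkContext}
import org.apache.spark.serializer.KryoSerializer

// Assumption: a live SparkContext, so SparkEnv.get works inside Aggregator.
val conf = new SparkConf().setMaster("local[2]").setAppName("aggregator-serializer-sketch")
val sc = new SparkContext(conf)

// Sum Int values per key, spilling with Kryo instead of the default SparkEnv serializer.
val sumAggregator = new Aggregator[String, Int, Long](
  createCombiner = (v: Int) => v.toLong,            // first value seen for a key
  mergeValue = (c: Long, v: Int) => c + v,          // fold another value into the combiner
  mergeCombiners = (c1: Long, c2: Long) => c1 + c2, // merge partial sums read back from spills
  serializer = new KryoSerializer(conf))

// combineValuesByKey consumes (key, value) pairs plus a task context (passed as null
// here, just as the Spark SQL change below does) and yields (key, combiner) pairs.
val totals = sumAggregator.combineValuesByKey(
  Iterator("a" -> 1, "b" -> 2, "a" -> 3), null).toMap
println(totals) // Map(a -> 4, b -> 3)

sc.stop()
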
@@ -80,6 +80,7 @@ abstract class AggregateFunction
override def dataType = base.dataType

def update(input: Row): Unit
def merge(other: AggregateFunction): Unit
override def eval(input: Row): Any

// Do we really need this?
@@ -189,6 +190,16 @@ case class AverageFunction(expr: Expression, base: AggregateExpression)
count += 1
sum.update(addFunction, input)
}

override def merge(other: AggregateFunction): Unit = {
other match {
case avg: AverageFunction => {
count += avg.count
sum.update(Add(sum, avg.sum), EmptyRow)
}
case _ => throw new TreeNodeException(this, s"Types do not match ${this.dataType} != ${other.dataType}")
}
}
}

case class CountFunction(expr: Expression, base: AggregateExpression) extends AggregateFunction {
@@ -203,6 +214,15 @@ case class CountFunction(expr: Expression, base: AggregateExpression) extends Ag
}
}

override def merge(other: AggregateFunction): Unit = {
other match {
case c: CountFunction => {
count += c.count
}
case _ => throw new TreeNodeException(this, s"Types do not match ${this.dataType} != ${other.dataType}")
}
}

override def eval(input: Row): Any = count
}

@@ -217,6 +237,15 @@ case class SumFunction(expr: Expression, base: AggregateExpression) extends Aggr
sum.update(addFunction, input)
}

override def merge(other: AggregateFunction): Unit = {
other match {
case s: SumFunction => {
sum.update(Add(sum, s.sum), EmptyRow)
}
case _ => throw new TreeNodeException(this, s"Types do not match ${this.dataType} != ${other.dataType}")
}
}

override def eval(input: Row): Any = sum.eval(null)
}

@@ -234,6 +263,19 @@ case class SumDistinctFunction(expr: Expression, base: AggregateExpression)
}
}

override def merge(other: AggregateFunction): Unit = {
other match {
case sd: SumDistinctFunction => {
// TODO(lamuguo): Change this to a HashSet union once the Scala rebase supports it. Related change:
// https://github.com/scala/scala/pull/3322
for (item <- sd.seen) {
seen += item
}
}
case _ => throw new TreeNodeException(this, s"Types do not match ${this.dataType} != ${other.dataType}")
}
}

override def eval(input: Row): Any =
seen.reduceLeft(base.dataType.asInstanceOf[NumericType].numeric.asInstanceOf[Numeric[Any]].plus)
}
@@ -252,6 +294,17 @@ case class CountDistinctFunction(expr: Seq[Expression], base: AggregateExpressio
}
}

override def merge(other: AggregateFunction): Unit = {
other match {
case cd: CountDistinctFunction => {
for (item <- cd.seen) {
seen += item
}
}
case _ => throw new TreeNodeException(this, s"Types do not match ${this.dataType} != ${other.dataType}")
}
}

override def eval(input: Row): Any = seen.size
}

@@ -266,5 +319,16 @@ case class FirstFunction(expr: Expression, base: AggregateExpression) extends Ag
}
}

override def merge(other: AggregateFunction): Unit = {
other match {
case second: FirstFunction => {
if (result == null) {
result = second.result
}
}
case _ => throw new TreeNodeException(this, s"Types do not match ${this.dataType} != ${other.dataType}")
}
}

override def eval(input: Row): Any = result
}
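
The merge(other) contract added to each function is what allows two partially evaluated buffers for the same group to be combined once spilled data is read back. Below is a standalone sketch of that contract using simplified stand-ins; the real classes need Expression and Row machinery to construct, so PartialAgg and PartialAverage are purely illustrative names, not part of the patch.

// Simplified model of the merge contract above.
trait PartialAgg {
  def update(value: Long): Unit
  def merge(other: PartialAgg): Unit
  def eval: Any
}

final class PartialAverage extends PartialAgg {
  private var count = 0L
  private var sum = 0L
  def update(value: Long): Unit = { count += 1; sum += value }
  // Mirrors AverageFunction.merge: add the counts and fold in the other buffer's sum.
  def merge(other: PartialAgg): Unit = other match {
    case avg: PartialAverage =>
      count += avg.count
      sum += avg.sum
    case _ =>
      throw new IllegalArgumentException(s"Types do not match: $other is not a PartialAverage")
  }
  def eval: Any = if (count == 0) null else sum.toDouble / count
}

// Two partial buffers for the same group are merged before producing the final result.
val left = new PartialAverage
Seq(1L, 2L, 3L).foreach(left.update)
val right = new PartialAverage
Seq(4L, 5L).foreach(right.update)
left.merge(right)
println(left.eval) // 3.0, i.e. (1 + 2 + 3 + 4 + 5) / 5

The real implementations do the same thing with catalyst expressions: AverageFunction adds the counts and folds the other sum in with Add(sum, avg.sum).
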
@@ -17,13 +17,13 @@

package org.apache.spark.sql.execution

import java.util.HashMap

import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.SparkContext
import org.apache.spark.{Logging, SparkConf, Aggregator, SparkContext}
import org.apache.spark.sql.catalyst.errors._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.physical._
import org.apache.spark.sql.execution.SparkSqlSerializer
import scala.collection.mutable.ArrayBuffer

/**
* :: DeveloperApi ::
@@ -42,7 +42,7 @@ case class Aggregate(
groupingExpressions: Seq[Expression],
aggregateExpressions: Seq[NamedExpression],
child: SparkPlan)(@transient sc: SparkContext)
extends UnaryNode with NoBind {
extends UnaryNode with NoBind with Logging {

override def requiredChildDistribution =
if (partial) {
@@ -155,48 +155,63 @@ case class Aggregate(
}
} else {
child.execute().mapPartitions { iter =>
val hashTable = new HashMap[Row, Array[AggregateFunction]]
val groupingProjection = new MutableProjection(groupingExpressions, childOutput)

var currentRow: Row = null
while (iter.hasNext) {
currentRow = iter.next()
val currentGroup = groupingProjection(currentRow)
var currentBuffer = hashTable.get(currentGroup)
if (currentBuffer == null) {
currentBuffer = newAggregateBuffer()
hashTable.put(currentGroup.copy(), currentBuffer)
def createCombiner(row: Row) = mergeValue(newAggregateBuffer(), row)
def mergeValue(buffer: Array[AggregateFunction], row: Row) = {
var i = 0
while (i < buffer.length) {
buffer(i).update(row)
i += 1
}
buffer
}
def mergeCombiners(buf1: Array[AggregateFunction], buf2: Array[AggregateFunction]) = {
if (buf1.length != buf2.length) {
throw new TreeNodeException(this, s"Unequal aggregate buffer length ${buf1.length} != ${buf2.length}")
}

var i = 0
while (i < currentBuffer.length) {
currentBuffer(i).update(currentRow)
while (i < buf1.length) {
buf1(i).merge(buf2(i))
i += 1
}
buf1
}

val aggregator = new Aggregator[Row, Row, Array[AggregateFunction]](
createCombiner, mergeValue, mergeCombiners, new SparkSqlSerializer(new SparkConf(false)))

val aggIter = aggregator.combineValuesByKey(
new Iterator[(Row, Row)] { // (groupKey, row)
override final def hasNext: Boolean = iter.hasNext

override final def next(): (Row, Row) = {
val row = iter.next()
// TODO: copy() here to suppress reference problems (the projection reuses its output row).
// Please address the root cause and remove this copy().
(groupingProjection(row).copy(), row)
}
},
null // TaskContext is not accessible here, so spill metrics are not recorded.
)
new Iterator[Row] {
private[this] val hashTableIter = hashTable.entrySet().iterator()
private[this] val aggregateResults = new GenericMutableRow(computedAggregates.length)
private[this] val resultProjection =
new MutableProjection(resultExpressions, computedSchema ++ namedGroups.map(_._2))
private[this] val resultProjection = new MutableProjection(
resultExpressions, computedSchema ++ namedGroups.map(_._2))
private[this] val joinedRow = new JoinedRow

override final def hasNext: Boolean = hashTableIter.hasNext
override final def hasNext: Boolean = aggIter.hasNext

override final def next(): Row = {
val currentEntry = hashTableIter.next()
val currentGroup = currentEntry.getKey
val currentBuffer = currentEntry.getValue
val entry = aggIter.next()
val group = entry._1
val data = entry._2

var i = 0
while (i < currentBuffer.length) {
// Evaluating an aggregate buffer returns the result. No row is required since we
// already added all rows in the group using update.
aggregateResults(i) = currentBuffer(i).eval(EmptyRow)
while (i < data.length) {
aggregateResults(i) = data(i).eval(EmptyRow)
i += 1
}
resultProjection(joinedRow(aggregateResults, currentGroup))

resultProjection(joinedRow(aggregateResults, group))
}
}
}
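
Seen from the API, the effect of the rewritten branch is that a grouped aggregation no longer has to hold every group's buffer in a java.util.HashMap; when spark.shuffle.spill is true it can spill through Aggregator and ExternalAppendOnlyMap. Below is a sketch of the kind of query that exercises this path, written against the 1.0-era Spark SQL API; treat the exact calls (SQLContext, createSchemaRDD, registerAsTable) as assumptions of the sketch rather than something the patch adds.

import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.sql.SQLContext

// Illustrative schema; any case class works with the createSchemaRDD implicit.
case class Sale(product: String, amount: Double)

val sc = new SparkContext(
  new SparkConf().setMaster("local[2]").setAppName("group-by-spill-sketch"))
val sqlContext = new SQLContext(sc)
import sqlContext.createSchemaRDD

val sales = sc.parallelize(Seq(Sale("a", 1.0), Sale("b", 2.0), Sale("a", 3.0)))
sales.registerAsTable("sales")

// With this patch, the grouped aggregation behind this GROUP BY runs through
// Aggregator/ExternalAppendOnlyMap and may spill instead of growing an in-memory HashMap.
sqlContext.sql("SELECT product, SUM(amount) FROM sales GROUP BY product")
  .collect()
  .foreach(println)

sc.stop()
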