Commit 0fc4aaa

[SPARK-14255][SQL] Streaming Aggregation
This PR adds the ability to perform aggregations inside of a `ContinuousQuery`. To implement this feature, the planning of aggregation has been augmented with a new `StatefulAggregationStrategy`. Unlike batch aggregation, stateful aggregation uses the `StateStore` (introduced in apache#11645) to persist the results of partial aggregation across different invocations. The resulting physical plan performs the aggregation using the following progression:

- Partial Aggregation
- Shuffle
- Partial Merge (now there is at most 1 tuple per group)
- StateStoreRestore (now there is 1 tuple from this batch + optionally one from the previous)
- Partial Merge (now there is at most 1 tuple per group)
- StateStoreSave (saves the tuple for the next batch)
- Complete (output the current result of the aggregation)

The following refactoring was also performed to allow us to plug into existing code:

- The get/put implementation is taken from apache#12013.
- The logic for breaking down and de-duplicating the physical execution of aggregation has been moved into a new pattern, `PhysicalAggregation`.
- The `AttributeReference` used to identify the result of an `AggregateFunction` has been moved into the `AggregateExpression` container. This places the reference in the same object as the other intermediate references used in aggregation and eliminates the need to pass around a `Map[(AggregateFunction, Boolean), Attribute]`. Further cleanup (using a different aggregation container for logical and physical plans) is deferred to a follow-up.
- Some planning logic has been moved from the `SessionState` into the `QueryExecution` to make it easier to override in the streaming case.
- The ability to write a `StreamTest` that checks only the output of the last batch has been added to simulate the future addition of output modes.

Author: Michael Armbrust <michael@databricks.com>

Closes apache#12048 from marmbrus/statefulAgg.
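To make the progression concrete, here is a minimal plain-Scala model of the state-store phases for a running `count`. This is illustrative only: it is not the PR's code and not the actual `StateStore` API.

```scala
object StatefulCountModel extends App {
  // What StateStoreSave persisted at the end of the previous batch.
  var store: Map[String, Long] = Map("a" -> 2L, "b" -> 1L)

  def runBatch(input: Seq[String]): Map[String, Long] = {
    // Partial aggregation + shuffle + partial merge: at most one tuple per group.
    val partial = input.groupBy(identity).map { case (k, vs) => k -> vs.size.toLong }
    // StateStoreRestore adds the tuple saved for each group by the previous batch;
    // the second partial merge combines it with this batch's tuple.
    val merged = (store.keySet ++ partial.keySet).map { k =>
      k -> (store.getOrElse(k, 0L) + partial.getOrElse(k, 0L))
    }.toMap
    store = merged // StateStoreSave: persist the tuple for the next batch.
    merged         // Complete: output the current result of the aggregation.
  }

  println(runBatch(Seq("a", "c", "a"))) // Map(a -> 4, b -> 1, c -> 1)
}
```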
1 parent 0b7d496 commit 0fc4aaa

File tree: 33 files changed (+827, -305 lines)

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala

Lines changed: 7 additions & 2 deletions
```diff
@@ -336,6 +336,11 @@ class Analyzer(
           Last(ifExpr(expr), Literal(true))
         case a: AggregateFunction =>
           a.withNewChildren(a.children.map(ifExpr))
+      }.transform {
+        // We are duplicating aggregates that are now computing a different value for each
+        // pivot value.
+        // TODO: Don't construct the physical container until after analysis.
+        case ae: AggregateExpression => ae.copy(resultId = NamedExpression.newExprId)
       }
       if (filteredAggregate.fastEquals(aggregate)) {
         throw new AnalysisException(
@@ -1153,11 +1158,11 @@ class Analyzer(

       // Extract Windowed AggregateExpression
       case we @ WindowExpression(
-          AggregateExpression(function, mode, isDistinct),
+          ae @ AggregateExpression(function, _, _, _),
           spec: WindowSpecDefinition) =>
         val newChildren = function.children.map(extractExpr)
         val newFunction = function.withNewChildren(newChildren).asInstanceOf[AggregateFunction]
-        val newAgg = AggregateExpression(newFunction, mode, isDistinct)
+        val newAgg = ae.copy(aggregateFunction = newFunction)
         seenWindowAggregates += newAgg
         WindowExpression(newAgg, spec)
```

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala

Lines changed: 1 addition & 1 deletion
```diff
@@ -76,7 +76,7 @@ trait CheckAnalysis {
           case g: GroupingID =>
             failAnalysis(s"grouping_id() can only be used with GroupingSets/Cube/Rollup")

-          case w @ WindowExpression(AggregateExpression(_, _, true), _) =>
+          case w @ WindowExpression(AggregateExpression(_, _, true, _), _) =>
             failAnalysis(s"Distinct window functions are not supported: $w")

           case w @ WindowExpression(_: OffsetWindowFunction, WindowSpecDefinition(_, order,
```

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/errors/package.scala

Lines changed: 5 additions & 2 deletions
```diff
@@ -25,15 +25,18 @@ import org.apache.spark.sql.catalyst.trees.TreeNode
 package object errors {

   class TreeNodeException[TreeType <: TreeNode[_]](
-      tree: TreeType, msg: String, cause: Throwable)
+      @transient val tree: TreeType,
+      msg: String,
+      cause: Throwable)
     extends Exception(msg, cause) {

+    val treeString = tree.toString
+
     // Yes, this is the same as a default parameter, but... those don't seem to work with SBT
     // external project dependencies for some reason.
     def this(tree: TreeType, msg: String) = this(tree, msg, null)

     override def getMessage: String = {
-      val treeString = tree.toString
       s"${super.getMessage}, tree:${if (treeString contains "\n") "\n" else " "}$tree"
     }
   }
```
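The motivation for this change: with `tree` marked `@transient`, the exception can cross a serialization boundary even when the tree itself is not serializable, and capturing `treeString` eagerly keeps a printable form available afterwards. A self-contained sketch of the pattern (hypothetical `Node`/`NodeException` classes, not Spark's):

```scala
import java.io._

// A deliberately non-serializable payload.
class Node(val label: String) { override def toString: String = label }

class NodeException(@transient val node: Node, msg: String) extends Exception(msg) {
  // Captured eagerly: after deserialization `node` is null, but this String survives.
  val nodeString = node.toString
  override def getMessage: String = s"${super.getMessage}, node: $nodeString"
}

object TransientDemo extends App {
  val bytes = new ByteArrayOutputStream()
  val oos = new ObjectOutputStream(bytes)
  oos.writeObject(new NodeException(new Node("Aggregate"), "planning failed"))
  oos.close()
  val copy = new ObjectInputStream(new ByteArrayInputStream(bytes.toByteArray))
    .readObject().asInstanceOf[NodeException]
  assert(copy.node == null) // the transient field was dropped
  println(copy.getMessage)  // "planning failed, node: Aggregate"
}
```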

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala

Lines changed: 36 additions & 1 deletion
```diff
@@ -18,6 +18,7 @@
 package org.apache.spark.sql.catalyst.expressions.aggregate

 import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback
 import org.apache.spark.sql.types._
@@ -66,17 +67,51 @@ private[sql] case object NoOp extends Expression with Unevaluable {
   override def children: Seq[Expression] = Nil
 }

+object AggregateExpression {
+  def apply(
+      aggregateFunction: AggregateFunction,
+      mode: AggregateMode,
+      isDistinct: Boolean): AggregateExpression = {
+    AggregateExpression(
+      aggregateFunction,
+      mode,
+      isDistinct,
+      NamedExpression.newExprId)
+  }
+}
+
 /**
  * A container for an [[AggregateFunction]] with its [[AggregateMode]] and a field
  * (`isDistinct`) indicating if DISTINCT keyword is specified for this function.
  */
 private[sql] case class AggregateExpression(
     aggregateFunction: AggregateFunction,
     mode: AggregateMode,
-    isDistinct: Boolean)
+    isDistinct: Boolean,
+    resultId: ExprId)
   extends Expression
   with Unevaluable {

+  lazy val resultAttribute: Attribute = if (aggregateFunction.resolved) {
+    AttributeReference(
+      aggregateFunction.toString,
+      aggregateFunction.dataType,
+      aggregateFunction.nullable)(exprId = resultId)
+  } else {
+    // This is a bit of a hack. Really we should not be constructing this container and reasoning
+    // about datatypes / aggregation mode until after we have finished analysis and made it to
+    // planning.
+    UnresolvedAttribute(aggregateFunction.toString)
+  }
+
+  // We compute the same thing regardless of our final result.
+  override lazy val canonicalized: Expression =
+    AggregateExpression(
+      aggregateFunction.canonicalized.asInstanceOf[AggregateFunction],
+      mode,
+      isDistinct,
+      ExprId(0))
+
   override def children: Seq[Expression] = aggregateFunction :: Nil
   override def dataType: DataType = aggregateFunction.dataType
   override def foldable: Boolean = false
```
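A consequence of storing `resultId` on the container: the companion `apply` mints a fresh id, while a case-class `copy` that leaves `resultId` alone preserves it. The optimizer rewrites below rely on the latter, so result expressions already bound to `resultAttribute` keep resolving. A simplified stand-alone model of those semantics (these are not the Spark classes):

```scala
final case class ExprId(id: Long)
object ExprId { private var c = 0L; def fresh(): ExprId = { c += 1; ExprId(c) } }

// Stand-in for AggregateExpression: like the new companion apply above, the
// two-argument apply mints a fresh result id.
final case class AggExpr(fn: String, isDistinct: Boolean, resultId: ExprId)
object AggExpr {
  def apply(fn: String, isDistinct: Boolean): AggExpr =
    AggExpr(fn, isDistinct, ExprId.fresh())
}

object ResultIdDemo extends App {
  val ae = AggExpr("count(x)", isDistinct = false)
  val rewritten = ae.copy(fn = "count(1)")              // optimizer-style rewrite
  assert(rewritten.resultId == ae.resultId)             // copy preserves the id
  assert(AggExpr("count(x)", isDistinct = false) != ae) // apply mints a fresh one
}
```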

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala

Lines changed: 1 addition & 1 deletion
```diff
@@ -329,7 +329,7 @@ case class PrettyAttribute(
   override def withName(newName: String): Attribute = throw new UnsupportedOperationException
   override def qualifier: Option[String] = throw new UnsupportedOperationException
   override def exprId: ExprId = throw new UnsupportedOperationException
-  override def nullable: Boolean = throw new UnsupportedOperationException
+  override def nullable: Boolean = true
 }

 object VirtualColumn {
```

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala

Lines changed: 7 additions & 7 deletions
```diff
@@ -534,7 +534,7 @@ object NullPropagation extends Rule[LogicalPlan] {

   def apply(plan: LogicalPlan): LogicalPlan = plan transform {
     case q: LogicalPlan => q transformExpressionsUp {
-      case e @ AggregateExpression(Count(exprs), _, _) if !exprs.exists(nonNullLiteral) =>
+      case e @ AggregateExpression(Count(exprs), _, _, _) if !exprs.exists(nonNullLiteral) =>
         Cast(Literal(0L), e.dataType)
       case e @ IsNull(c) if !c.nullable => Literal.create(false, BooleanType)
       case e @ IsNotNull(c) if !c.nullable => Literal.create(true, BooleanType)
@@ -547,9 +547,9 @@ object NullPropagation extends Rule[LogicalPlan] {
         Literal.create(null, e.dataType)
       case e @ EqualNullSafe(Literal(null, _), r) => IsNull(r)
       case e @ EqualNullSafe(l, Literal(null, _)) => IsNull(l)
-      case e @ AggregateExpression(Count(exprs), mode, false) if !exprs.exists(_.nullable) =>
+      case ae @ AggregateExpression(Count(exprs), _, false, _) if !exprs.exists(_.nullable) =>
         // This rule should be only triggered when isDistinct field is false.
-        AggregateExpression(Count(Literal(1)), mode, isDistinct = false)
+        ae.copy(aggregateFunction = Count(Literal(1)))

       // For Coalesce, remove null literals.
       case e @ Coalesce(children) =>
@@ -1225,13 +1225,13 @@ object DecimalAggregates extends Rule[LogicalPlan] {
   private val MAX_DOUBLE_DIGITS = 15

   def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions {
-    case AggregateExpression(Sum(e @ DecimalType.Expression(prec, scale)), mode, isDistinct)
+    case ae @ AggregateExpression(Sum(e @ DecimalType.Expression(prec, scale)), _, _, _)
         if prec + 10 <= MAX_LONG_DIGITS =>
-      MakeDecimal(AggregateExpression(Sum(UnscaledValue(e)), mode, isDistinct), prec + 10, scale)
+      MakeDecimal(ae.copy(aggregateFunction = Sum(UnscaledValue(e))), prec + 10, scale)

-    case AggregateExpression(Average(e @ DecimalType.Expression(prec, scale)), mode, isDistinct)
+    case ae @ AggregateExpression(Average(e @ DecimalType.Expression(prec, scale)), _, _, _)
         if prec + 4 <= MAX_DOUBLE_DIGITS =>
-      val newAggExpr = AggregateExpression(Average(UnscaledValue(e)), mode, isDistinct)
+      val newAggExpr = ae.copy(aggregateFunction = Average(UnscaledValue(e)))
       Cast(
         Divide(newAggExpr, Literal.create(math.pow(10.0, scale), DoubleType)),
         DecimalType(prec + 4, scale + 4))
```

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala

Lines changed: 73 additions & 0 deletions
```diff
@@ -22,6 +22,7 @@ import scala.collection.mutable

 import org.apache.spark.internal.Logging
 import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
 import org.apache.spark.sql.catalyst.plans._
 import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.types.IntegerType
@@ -216,3 +217,75 @@ object IntegerIndex {
     case _ => None
   }
 }
+
+/**
+ * An extractor used when planning the physical execution of an aggregation. Compared with a
+ * logical aggregation, the following transformations are performed:
+ *  - Unnamed grouping expressions are named so that they can be referred to across phases of
+ *    aggregation.
+ *  - Aggregations that appear multiple times are deduplicated.
+ *  - The computation of the aggregations themselves is separated from the final result. For
+ *    example, the `count` in `count + 1` will be split into an [[AggregateExpression]] and a
+ *    final computation that computes `count.resultAttribute + 1`.
+ */
+object PhysicalAggregation {
+  // groupingExpressions, aggregateExpressions, resultExpressions, child
+  type ReturnType =
+    (Seq[NamedExpression], Seq[AggregateExpression], Seq[NamedExpression], LogicalPlan)
+
+  def unapply(a: Any): Option[ReturnType] = a match {
+    case logical.Aggregate(groupingExpressions, resultExpressions, child) =>
+      // A single aggregate expression might appear multiple times in resultExpressions.
+      // In order to avoid evaluating an individual aggregate function multiple times, we'll
+      // build a set of the distinct aggregate expressions and build a function which can
+      // be used to re-write expressions so that they reference the single copy of the
+      // aggregate function which actually gets computed.
+      val aggregateExpressions = resultExpressions.flatMap { expr =>
+        expr.collect {
+          case agg: AggregateExpression => agg
+        }
+      }.distinct
+
+      val namedGroupingExpressions = groupingExpressions.map {
+        case ne: NamedExpression => ne -> ne
+        // If the expression is not a NamedExpression, we add an alias.
+        // So, when we generate the result of the operator, the Aggregate Operator
+        // can directly get the Seq of attributes representing the grouping expressions.
+        case other =>
+          val withAlias = Alias(other, other.toString)()
+          other -> withAlias
+      }
+      val groupExpressionMap = namedGroupingExpressions.toMap
+
+      // The original `resultExpressions` are a set of expressions which may reference
+      // aggregate expressions, grouping column values, and constants. When the aggregate
+      // operator emits output rows, we will use `resultExpressions` to generate an output
+      // projection which takes the grouping columns and final aggregate result buffer as input.
+      // Thus, we must re-write the result expressions so that their attributes match up with
+      // the attributes of the final result projection's input row:
+      val rewrittenResultExpressions = resultExpressions.map { expr =>
+        expr.transformDown {
+          case ae: AggregateExpression =>
+            // The final aggregation buffer's attributes will be `finalAggregationAttributes`,
+            // so replace each aggregate expression by its corresponding attribute in the set:
+            ae.resultAttribute
+          case expression =>
+            // Since we're using `namedGroupingAttributes` to extract the grouping key
+            // columns, we need to replace grouping key expressions with their corresponding
+            // attributes. We do not rely on the equality check here since attributes may
+            // differ cosmetically. Instead, we use semanticEquals.
+            groupExpressionMap.collectFirst {
+              case (expr, ne) if expr semanticEquals expression => ne.toAttribute
+            }.getOrElse(expression)
+        }.asInstanceOf[NamedExpression]
+      }
+
+      Some((
+        namedGroupingExpressions.map(_._2),
+        aggregateExpressions,
+        rewrittenResultExpressions,
+        child))
+
+    case _ => None
+  }
+}
```
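The PR's `StatefulAggregationStrategy` plans directly off this extractor. Roughly (a sketch, not the PR's exact code: real strategies are defined inside the planner so that `planLater` is in scope, and `planStreamingAggregation` stands in for the PR's planning helper):

```scala
object StatefulAggregationStrategy extends Strategy {
  override def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
    case PhysicalAggregation(groupingExpressions, aggregateExpressions, resultExpressions, child) =>
      // Grouping keys are already named, aggregates are deduplicated, and the
      // result expressions are rewritten against each aggregate's resultAttribute.
      planStreamingAggregation(
        groupingExpressions, aggregateExpressions, resultExpressions, planLater(child))
    case _ => Nil
  }
}
```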

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala

Lines changed: 3 additions & 0 deletions
```diff
@@ -19,6 +19,7 @@ package org.apache.spark.sql.catalyst.plans

 import org.apache.spark.SparkFunSuite
 import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
 import org.apache.spark.sql.catalyst.plans.logical.{Filter, LogicalPlan, OneRowRelation, Sample}
 import org.apache.spark.sql.catalyst.util._

@@ -38,6 +39,8 @@ abstract class PlanTest extends SparkFunSuite with PredicateHelper {
       AttributeReference(a.name, a.dataType, a.nullable)(exprId = ExprId(0))
     case a: Alias =>
       Alias(a.child, a.name)(exprId = ExprId(0))
+    case ae: AggregateExpression =>
+      ae.copy(resultId = ExprId(0))
   }
 }
```
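This normalization exists for the same reason `exprId`s are zeroed out above: two logically identical plans built independently mint different `resultId`s. Using the simplified `AggExpr` model from the interfaces.scala note:

```scala
val a = AggExpr("count(x)", isDistinct = false)
val b = AggExpr("count(x)", isDistinct = false)
assert(a != b) // fresh ids make structurally identical aggregates compare unequal...
// ...so the test harness forces them to a canonical id before comparing plans.
assert(a.copy(resultId = ExprId(0)) == b.copy(resultId = ExprId(0)))
```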

sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala

Lines changed: 22 additions & 2 deletions
```diff
@@ -21,6 +21,8 @@ import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.{AnalysisException, SQLContext}
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, ReturnAnswer}
+import org.apache.spark.sql.catalyst.rules.Rule
+import org.apache.spark.sql.execution.exchange.{EnsureRequirements, ReuseExchange}

 /**
  * The primary workflow for executing relational queries using Spark. Designed to allow easy
@@ -31,6 +33,9 @@ import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, ReturnAnswer}
  */
 class QueryExecution(val sqlContext: SQLContext, val logical: LogicalPlan) {

+  // TODO: Move the planner and optimizer into here from SessionState.
+  protected def planner = sqlContext.sessionState.planner
+
   def assertAnalyzed(): Unit = try sqlContext.sessionState.analyzer.checkAnalysis(analyzed) catch {
     case e: AnalysisException =>
       val ae = new AnalysisException(e.message, e.line, e.startPosition, Some(analyzed))
@@ -49,16 +54,31 @@ class QueryExecution(val sqlContext: SQLContext, val logical: LogicalPlan) {

   lazy val sparkPlan: SparkPlan = {
     SQLContext.setActive(sqlContext)
-    sqlContext.sessionState.planner.plan(ReturnAnswer(optimizedPlan)).next()
+    planner.plan(ReturnAnswer(optimizedPlan)).next()
   }

   // executedPlan should not be used to initialize any SparkPlan. It should be
   // only used for execution.
-  lazy val executedPlan: SparkPlan = sqlContext.sessionState.prepareForExecution.execute(sparkPlan)
+  lazy val executedPlan: SparkPlan = prepareForExecution(sparkPlan)

   /** Internal version of the RDD. Avoids copies and has no schema */
   lazy val toRdd: RDD[InternalRow] = executedPlan.execute()

+  /**
+   * Prepares a planned [[SparkPlan]] for execution by inserting shuffle operations and internal
+   * row format conversions as needed.
+   */
+  protected def prepareForExecution(plan: SparkPlan): SparkPlan = {
+    preparations.foldLeft(plan) { case (sp, rule) => rule.apply(sp) }
+  }
+
+  /** A sequence of rules that will be applied in order to the physical plan before execution. */
+  protected def preparations: Seq[Rule[SparkPlan]] = Seq(
+    PlanSubqueries(sqlContext),
+    EnsureRequirements(sqlContext.conf),
+    CollapseCodegenStages(sqlContext.conf),
+    ReuseExchange(sqlContext.conf))
+
   protected def stringOrError[A](f: => A): String =
     try f.toString catch { case e: Throwable => e.toString }
```
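These hooks are what the streaming side overrides: the PR adds an `IncrementalExecution` that extends `QueryExecution` along these lines (a rough sketch; the exact strategy list and state-store rule wiring in the PR differ, and `stateStorePreparationRule` is a hypothetical name):

```scala
class IncrementalExecution(ctx: SQLContext, logicalPlan: LogicalPlan)
  extends QueryExecution(ctx, logicalPlan) {

  // Try the stateful strategy before the session's default strategies.
  override protected def planner = new SparkPlanner(ctx) {
    override def strategies: Seq[Strategy] =
      StatefulAggregationStrategy +: super.strategies
  }

  // Prepend a rule that wires StateStoreRestore/StateStoreSave operators to a
  // concrete store version for this batch (hypothetical rule name).
  override protected def preparations: Seq[Rule[SparkPlan]] =
    stateStorePreparationRule +: super.preparations
}
```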

sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala

Lines changed: 7 additions & 0 deletions
```diff
@@ -379,6 +379,13 @@ private[sql] trait LeafNode extends SparkPlan {
   override def producedAttributes: AttributeSet = outputSet
 }

+object UnaryNode {
+  def unapply(a: Any): Option[(SparkPlan, SparkPlan)] = a match {
+    case s: SparkPlan if s.children.size == 1 => Some((s, s.children.head))
+    case _ => None
+  }
+}
+
 private[sql] trait UnaryNode extends SparkPlan {
   def child: SparkPlan
```
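The companion extractor lets a rule match any single-child physical operator without enumerating concrete classes. An illustrative rule (not from the PR), assumed to live in `org.apache.spark.sql.execution` so the `private[sql]` types are visible:

```scala
import org.apache.spark.sql.catalyst.rules.Rule

// Log every single-child operator in a physical plan. Rule extends Logging,
// so logDebug is available.
object LogUnaryOperators extends Rule[SparkPlan] {
  override def apply(plan: SparkPlan): SparkPlan = plan transform {
    case UnaryNode(op, child) =>
      logDebug(s"${op.nodeName} has single child ${child.nodeName}")
      op
  }
}
```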
