
Commit 9cba8a3

Merge remote-tracking branch 'upstream/master' into compoundNullFilter
2 parents 7b314e1 + e33bc67

13 files changed: +331 -67 lines changed

core/src/main/scala/org/apache/spark/SparkConf.scala

Lines changed: 1 addition & 1 deletion

@@ -718,7 +718,7 @@ private[spark] object SparkConf extends Logging {
     allAlternatives.get(key).foreach { case (newKey, cfg) =>
       logWarning(
         s"The configuration key '$key' has been deprecated as of Spark ${cfg.version} and " +
-          s"and may be removed in the future. Please use the new key '$newKey' instead.")
+          s"may be removed in the future. Please use the new key '$newKey' instead.")
       return
     }
     if (key.startsWith("spark.akka") || key.startsWith("spark.ssl.akka")) {
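
Not part of the diff, but for context, a minimal sketch of when this warning fires, using a key that SparkConf's alternatives table is known to map to a newer name:

    import org.apache.spark.SparkConf

    // "spark.kryoserializer.buffer.mb" was superseded by
    // "spark.kryoserializer.buffer" as of Spark 1.4, so setting the old key
    // logs the (now corrected) deprecation message:
    val conf = new SparkConf()
    conf.set("spark.kryoserializer.buffer.mb", "64")
    // WARN: The configuration key 'spark.kryoserializer.buffer.mb' has been
    // deprecated as of Spark 1.4 and may be removed in the future. Please use
    // the new key 'spark.kryoserializer.buffer' instead.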

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala

Lines changed: 6 additions & 11 deletions

@@ -421,7 +421,7 @@ class Analyzer(
         val newOutput = oldVersion.generatorOutput.map(_.newInstance())
         (oldVersion, oldVersion.copy(generatorOutput = newOutput))

-      case oldVersion @ Window(_, windowExpressions, _, _, child)
+      case oldVersion @ Window(windowExpressions, _, _, child)
           if AttributeSet(windowExpressions.map(_.toAttribute)).intersect(conflictingAttributes)
             .nonEmpty =>
         (oldVersion, oldVersion.copy(windowExpressions = newAliases(windowExpressions)))
@@ -658,10 +658,6 @@ class Analyzer(
       case p: Project =>
         val missing = missingAttrs -- p.child.outputSet
         Project(p.projectList ++ missingAttrs, addMissingAttr(p.child, missing))
-      case w: Window =>
-        val missing = missingAttrs -- w.child.outputSet
-        w.copy(projectList = w.projectList ++ missingAttrs,
-          child = addMissingAttr(w.child, missing))
      case a: Aggregate =>
        // all the missing attributes should be grouping expressions
        // TODO: push down AggregateExpression
@@ -1166,7 +1162,6 @@ class Analyzer(
          // Set currentChild to the newly created Window operator.
          currentChild =
            Window(
-              currentChild.output,
              windowExpressions,
              partitionSpec,
              orderSpec,
@@ -1199,7 +1194,7 @@ class Analyzer(
        val withWindow = addWindow(windowExpressions, withFilter)

        // Finally, generate output columns according to the original projectList.
-        val finalProjectList = aggregateExprs.map (_.toAttribute)
+        val finalProjectList = aggregateExprs.map(_.toAttribute)
        Project(finalProjectList, withWindow)

      case p: LogicalPlan if !p.childrenResolved => p
@@ -1215,7 +1210,7 @@ class Analyzer(
        val withWindow = addWindow(windowExpressions, withAggregate)

        // Finally, generate output columns according to the original projectList.
-        val finalProjectList = aggregateExprs.map (_.toAttribute)
+        val finalProjectList = aggregateExprs.map(_.toAttribute)
        Project(finalProjectList, withWindow)

      // We only extract Window Expressions after all expressions of the Project
@@ -1230,7 +1225,7 @@ class Analyzer(
        val withWindow = addWindow(windowExpressions, withProject)

        // Finally, generate output columns according to the original projectList.
-        val finalProjectList = projectList.map (_.toAttribute)
+        val finalProjectList = projectList.map(_.toAttribute)
        Project(finalProjectList, withWindow)
    }
  }
@@ -1436,10 +1431,10 @@ object CleanupAliases extends Rule[LogicalPlan] {
      val cleanedAggs = aggs.map(trimNonTopLevelAliases(_).asInstanceOf[NamedExpression])
      Aggregate(grouping.map(trimAliases), cleanedAggs, child)

-    case w @ Window(projectList, windowExprs, partitionSpec, orderSpec, child) =>
+    case w @ Window(windowExprs, partitionSpec, orderSpec, child) =>
      val cleanedWindowExprs =
        windowExprs.map(e => trimNonTopLevelAliases(e).asInstanceOf[NamedExpression])
-      Window(projectList, cleanedWindowExprs, partitionSpec.map(trimAliases),
+      Window(cleanedWindowExprs, partitionSpec.map(trimAliases),
        orderSpec.map(trimAliases(_).asInstanceOf[SortOrder]), child)

    // Operators that operate on objects should only have expressions from encoders, which should
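
With projectList removed, Window forwards its child's columns implicitly, which is what lets the addMissingAttr special case above be dropped. A rough sketch (attribute names assumed) of the plan ExtractWindowExpressions now builds for SELECT a, count(b) OVER (PARTITION BY a ORDER BY b) FROM t:

    // Project(finalProjectList = a :: window#0 :: Nil,
    //   Window(windowExpressions = count(b) over (...) as window#0 :: Nil,
    //     partitionSpec = a :: Nil,
    //     orderSpec = b ASC :: Nil,
    //     child = t))
    //
    // Window.output == child.output ++ window attributes, so the final
    // Project alone restores the original projection.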

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala

Lines changed: 6 additions & 0 deletions

@@ -268,6 +268,12 @@ package object dsl {
       Aggregate(groupingExprs, aliasedExprs, logicalPlan)
     }

+    def window(
+        windowExpressions: Seq[NamedExpression],
+        partitionSpec: Seq[Expression],
+        orderSpec: Seq[SortOrder]): LogicalPlan =
+      Window(windowExpressions, partitionSpec, orderSpec, logicalPlan)
+
     def subquery(alias: Symbol): LogicalPlan = SubqueryAlias(alias.name, logicalPlan)

     def except(otherPlan: LogicalPlan): LogicalPlan = Except(logicalPlan, otherPlan)
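
The new helper mirrors the existing groupBy/select combinators and is exercised by the ColumnPruningSuite changes below; a minimal usage sketch with the catalyst DSL (expression imports elided):

    import org.apache.spark.sql.catalyst.dsl.expressions._
    import org.apache.spark.sql.catalyst.dsl.plans._

    val input = LocalRelation('a.int, 'b.string)
    // count(b) OVER (PARTITION BY a ORDER BY b ASC), aliased as 'window:
    val plan = input.window(
      WindowExpression(
        AggregateExpression(Count('b), Complete, isDistinct = false),
        WindowSpecDefinition('a :: Nil, SortOrder('b, Ascending) :: Nil,
          UnspecifiedFrame)).as('window) :: Nil,
      'a :: Nil,
      'b.asc :: Nil)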

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala

Lines changed: 2 additions & 1 deletion

@@ -57,7 +57,8 @@ case class SortOrder(child: Expression, direction: SortDirection)
   override def dataType: DataType = child.dataType
   override def nullable: Boolean = child.nullable

-  override def toString: String = s"$child ${if (direction == Ascending) "ASC" else "DESC"}"
+  override def toString: String = s"$child ${direction.sql}"
+  override def sql: String = child.sql + " " + direction.sql

   def isAscending: Boolean = direction == Ascending
 }
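
A rough illustration of the new method (direction.sql renders ASC or DESC, as the replaced toString implies; exact quoting of the child depends on its own sql):

    // SortOrder('a.attr, Ascending).sql   ≈ "a ASC"
    // SortOrder('a.attr, Descending).sql  ≈ "a DESC"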

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala

Lines changed: 42 additions & 5 deletions

@@ -18,7 +18,8 @@
 package org.apache.spark.sql.catalyst.expressions

 import org.apache.spark.sql.AnalysisException
-import org.apache.spark.sql.catalyst.analysis.UnresolvedException
+import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, UnresolvedException}
+import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.{TypeCheckFailure, TypeCheckSuccess}
 import org.apache.spark.sql.catalyst.expressions.aggregate.{DeclarativeAggregate, NoOp}
 import org.apache.spark.sql.types._

@@ -30,6 +31,7 @@ sealed trait WindowSpec

 /**
  * The specification for a window function.
+ *
  * @param partitionSpec It defines the way that input rows are partitioned.
  * @param orderSpec It defines the ordering of rows in a partition.
  * @param frameSpecification It defines the window frame in a partition.
@@ -75,6 +77,22 @@ case class WindowSpecDefinition(
   override def nullable: Boolean = true
   override def foldable: Boolean = false
   override def dataType: DataType = throw new UnsupportedOperationException
+
+  override def sql: String = {
+    val partition = if (partitionSpec.isEmpty) {
+      ""
+    } else {
+      "PARTITION BY " + partitionSpec.map(_.sql).mkString(", ")
+    }
+
+    val order = if (orderSpec.isEmpty) {
+      ""
+    } else {
+      "ORDER BY " + orderSpec.map(_.sql).mkString(", ")
+    }
+
+    s"($partition $order ${frameSpecification.toString})"
+  }
 }

 /**
@@ -278,6 +296,7 @@ case class WindowExpression(
   override def nullable: Boolean = windowFunction.nullable

   override def toString: String = s"$windowFunction $windowSpec"
+  override def sql: String = windowFunction.sql + " OVER " + windowSpec.sql
 }

 /**
@@ -451,6 +470,7 @@ object SizeBasedWindowFunction {
     the window partition.""")
 case class RowNumber() extends RowNumberLike {
   override val evaluateExpression = rowNumber
+  override def sql: String = "ROW_NUMBER()"
 }

 /**
@@ -470,6 +490,7 @@ case class CumeDist() extends RowNumberLike with SizeBasedWindowFunction {
   // return the same value for equal values in the partition.
   override val frame = SpecifiedWindowFrame(RangeFrame, UnboundedPreceding, CurrentRow)
   override val evaluateExpression = Divide(Cast(rowNumber, DoubleType), Cast(n, DoubleType))
+  override def sql: String = "CUME_DIST()"
 }

 /**
@@ -499,12 +520,25 @@ case class CumeDist() extends RowNumberLike with SizeBasedWindowFunction {
 case class NTile(buckets: Expression) extends RowNumberLike with SizeBasedWindowFunction {
   def this() = this(Literal(1))

+  override def children: Seq[Expression] = Seq(buckets)
+
   // Validate buckets. Note that this could be relaxed, the bucket value only needs to constant
   // for each partition.
-  buckets.eval() match {
-    case b: Int if b > 0 => // Ok
-    case x => throw new AnalysisException(
-      "Buckets expression must be a foldable positive integer expression: $x")
+  override def checkInputDataTypes(): TypeCheckResult = {
+    if (!buckets.foldable) {
+      return TypeCheckFailure(s"Buckets expression must be foldable, but got $buckets")
+    }
+
+    if (buckets.dataType != IntegerType) {
+      return TypeCheckFailure(s"Buckets expression must be integer type, but got $buckets")
+    }
+
+    val i = buckets.eval().asInstanceOf[Int]
+    if (i > 0) {
+      TypeCheckSuccess
+    } else {
+      TypeCheckFailure(s"Buckets expression must be positive, but got: $i")
+    }
   }

   private val bucket = AttributeReference("bucket", IntegerType, nullable = false)()
@@ -608,6 +642,7 @@ abstract class RankLike extends AggregateWindowFunction {
 case class Rank(children: Seq[Expression]) extends RankLike {
   def this() = this(Nil)
   override def withOrder(order: Seq[Expression]): Rank = Rank(order)
+  override def sql: String = "RANK()"
 }

 /**
@@ -632,6 +667,7 @@ case class DenseRank(children: Seq[Expression]) extends RankLike {
   override val updateExpressions = increaseRank +: children
   override val aggBufferAttributes = rank +: orderAttrs
   override val initialValues = zero +: orderInit
+  override def sql: String = "DENSE_RANK()"
 }

 /**
@@ -658,4 +694,5 @@ case class PercentRank(children: Seq[Expression]) extends RankLike with SizeBase
   override val evaluateExpression = If(GreaterThan(n, one),
     Divide(Cast(Subtract(rank, one), DoubleType), Cast(Subtract(n, one), DoubleType)),
     Literal(0.0d))
+  override def sql: String = "PERCENT_RANK()"
 }
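
Taken together, the new sql overrides let a window expression render back to SQL text; a rough sketch (exact quoting and the trailing frame text depend on Expression.sql and WindowFrame.toString):

    val spec = WindowSpecDefinition('a :: Nil,
      SortOrder('b, Ascending) :: Nil, UnspecifiedFrame)
    WindowExpression(RowNumber(), spec).sql
    // ≈ "ROW_NUMBER() OVER (PARTITION BY a ORDER BY b ASC ...)"

NTile validation also moves from construction time to analysis: NTile(Literal(-1)) no longer throws in the constructor; checkInputDataTypes() returns TypeCheckFailure("Buckets expression must be positive, but got: -1"). As a side effect this fixes the old error message, which was missing its s interpolator and printed a literal "$x".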

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala

Lines changed: 11 additions & 9 deletions

@@ -315,21 +315,17 @@ object SetOperationPushDown extends Rule[LogicalPlan] with PredicateHelper {
  * - LeftSemiJoin
  */
 object ColumnPruning extends Rule[LogicalPlan] {
-  def sameOutput(output1: Seq[Attribute], output2: Seq[Attribute]): Boolean =
+  private def sameOutput(output1: Seq[Attribute], output2: Seq[Attribute]): Boolean =
     output1.size == output2.size &&
       output1.zip(output2).forall(pair => pair._1.semanticEquals(pair._2))

   def apply(plan: LogicalPlan): LogicalPlan = plan transform {
-    // Prunes the unused columns from project list of Project/Aggregate/Window/Expand
+    // Prunes the unused columns from project list of Project/Aggregate/Expand
     case p @ Project(_, p2: Project) if (p2.outputSet -- p.references).nonEmpty =>
       p.copy(child = p2.copy(projectList = p2.projectList.filter(p.references.contains)))
     case p @ Project(_, a: Aggregate) if (a.outputSet -- p.references).nonEmpty =>
       p.copy(
         child = a.copy(aggregateExpressions = a.aggregateExpressions.filter(p.references.contains)))
-    case p @ Project(_, w: Window) if (w.outputSet -- p.references).nonEmpty =>
-      p.copy(child = w.copy(
-        projectList = w.projectList.filter(p.references.contains),
-        windowExpressions = w.windowExpressions.filter(p.references.contains)))
     case a @ Project(_, e @ Expand(_, _, grandChild)) if (e.outputSet -- a.references).nonEmpty =>
       val newOutput = e.output.filter(a.references.contains(_))
       val newProjects = e.projections.map { proj =>
@@ -343,11 +339,9 @@ object ColumnPruning extends Rule[LogicalPlan] {
     case mp @ MapPartitions(_, _, _, child) if (child.outputSet -- mp.references).nonEmpty =>
       mp.copy(child = prunedChild(child, mp.references))

-    // Prunes the unused columns from child of Aggregate/Window/Expand/Generate
+    // Prunes the unused columns from child of Aggregate/Expand/Generate
     case a @ Aggregate(_, _, child) if (child.outputSet -- a.references).nonEmpty =>
       a.copy(child = prunedChild(child, a.references))
-    case w @ Window(_, _, _, _, child) if (child.outputSet -- w.references).nonEmpty =>
-      w.copy(child = prunedChild(child, w.references))
     case e @ Expand(_, _, child) if (child.outputSet -- e.references).nonEmpty =>
       e.copy(child = prunedChild(child, e.references))
     case g: Generate if !g.join && (g.child.outputSet -- g.references).nonEmpty =>
@@ -381,6 +375,14 @@ object ColumnPruning extends Rule[LogicalPlan] {
       p
     }

+    // Prune unnecessary window expressions
+    case p @ Project(_, w: Window) if (w.windowOutputSet -- p.references).nonEmpty =>
+      p.copy(child = w.copy(
+        windowExpressions = w.windowExpressions.filter(p.references.contains)))
+
+    // Eliminate no-op Window
+    case w: Window if w.windowExpressions.isEmpty => w.child
+
     // Eliminate no-op Projects
     case p @ Project(projectList, child) if sameOutput(child.output, p.output) => child

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala

Lines changed: 3 additions & 2 deletions

@@ -434,14 +434,15 @@ case class Aggregate(
 }

 case class Window(
-    projectList: Seq[Attribute],
     windowExpressions: Seq[NamedExpression],
     partitionSpec: Seq[Expression],
     orderSpec: Seq[SortOrder],
     child: LogicalPlan) extends UnaryNode {

   override def output: Seq[Attribute] =
-    projectList ++ windowExpressions.map(_.toAttribute)
+    child.output ++ windowExpressions.map(_.toAttribute)
+
+  def windowOutputSet: AttributeSet = AttributeSet(windowExpressions.map(_.toAttribute))
 }

 private[sql] object Expand {
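
With projectList gone, the operator's schema is derived entirely from its child; a short sketch of the two accessors (names as in the new ColumnPruningSuite tests):

    // w = Window(cnt :: Nil, 'a :: Nil, 'b.asc :: Nil, child)   // child outputs a, b
    // w.output          == Seq(a, b, cnt)    // child.output ++ window attributes
    // w.windowOutputSet == AttributeSet(cnt) // only the window-produced attributes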

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ColumnPruningSuite.scala

Lines changed: 66 additions & 2 deletions

@@ -23,7 +23,8 @@ import org.apache.spark.sql.catalyst.analysis
 import org.apache.spark.sql.catalyst.dsl.expressions._
 import org.apache.spark.sql.catalyst.dsl.plans._
 import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
-import org.apache.spark.sql.catalyst.expressions.{Ascending, Explode, Literal, SortOrder}
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, Complete, Count}
 import org.apache.spark.sql.catalyst.plans.{Inner, PlanTest}
 import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.catalyst.rules.RuleExecutor
@@ -33,7 +34,8 @@ class ColumnPruningSuite extends PlanTest {

   object Optimize extends RuleExecutor[LogicalPlan] {
     val batches = Batch("Column pruning", FixedPoint(100),
-      ColumnPruning) :: Nil
+      ColumnPruning,
+      CollapseProject) :: Nil
   }

   test("Column pruning for Generate when Generate.join = false") {
@@ -258,6 +260,68 @@ class ColumnPruningSuite extends PlanTest {
     comparePlans(optimized1, analysis.EliminateSubqueryAliases(correctAnswer1))
   }

+  test("Column pruning on Window with useless aggregate functions") {
+    val input = LocalRelation('a.int, 'b.string, 'c.double, 'd.int)
+
+    val originalQuery =
+      input.groupBy('a, 'c, 'd)('a, 'c, 'd,
+        WindowExpression(
+          AggregateExpression(Count('b), Complete, isDistinct = false),
+          WindowSpecDefinition('a :: Nil,
+            SortOrder('b, Ascending) :: Nil,
+            UnspecifiedFrame)).as('window)).select('a, 'c)
+
+    val correctAnswer = input.select('a, 'c, 'd).groupBy('a, 'c, 'd)('a, 'c).analyze
+
+    val optimized = Optimize.execute(originalQuery.analyze)
+
+    comparePlans(optimized, correctAnswer)
+  }
+
+  test("Column pruning on Window with selected agg expressions") {
+    val input = LocalRelation('a.int, 'b.string, 'c.double, 'd.int)
+
+    val originalQuery =
+      input.select('a, 'b, 'c, 'd,
+        WindowExpression(
+          AggregateExpression(Count('b), Complete, isDistinct = false),
+          WindowSpecDefinition('a :: Nil,
+            SortOrder('b, Ascending) :: Nil,
+            UnspecifiedFrame)).as('window)).where('window > 1).select('a, 'c)
+
+    val correctAnswer =
+      input.select('a, 'b, 'c)
+        .window(WindowExpression(
+          AggregateExpression(Count('b), Complete, isDistinct = false),
+          WindowSpecDefinition('a :: Nil,
+            SortOrder('b, Ascending) :: Nil,
+            UnspecifiedFrame)).as('window) :: Nil,
+          'a :: Nil, 'b.asc :: Nil)
+        .select('a, 'c, 'window).where('window > 1).select('a, 'c).analyze
+
+    val optimized = Optimize.execute(originalQuery.analyze)
+
+    comparePlans(optimized, correctAnswer)
+  }
+
+  test("Column pruning on Window in select") {
+    val input = LocalRelation('a.int, 'b.string, 'c.double, 'd.int)
+
+    val originalQuery =
+      input.select('a, 'b, 'c, 'd,
+        WindowExpression(
+          AggregateExpression(Count('b), Complete, isDistinct = false),
+          WindowSpecDefinition('a :: Nil,
+            SortOrder('b, Ascending) :: Nil,
+            UnspecifiedFrame)).as('window)).select('a, 'c)
+
+    val correctAnswer = input.select('a, 'c).analyze
+
+    val optimized = Optimize.execute(originalQuery.analyze)
+
+    comparePlans(optimized, correctAnswer)
+  }
+
   test("Column pruning on Union") {
     val input1 = LocalRelation('a.int, 'b.string, 'c.double)
     val input2 = LocalRelation('c.int, 'd.string, 'e.double)

sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala

Lines changed: 2 additions & 3 deletions

@@ -344,9 +344,8 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
         execution.Filter(condition, planLater(child)) :: Nil
       case e @ logical.Expand(_, _, child) =>
         execution.Expand(e.projections, e.output, planLater(child)) :: Nil
-      case logical.Window(projectList, windowExprs, partitionSpec, orderSpec, child) =>
-        execution.Window(
-          projectList, windowExprs, partitionSpec, orderSpec, planLater(child)) :: Nil
+      case logical.Window(windowExprs, partitionSpec, orderSpec, child) =>
+        execution.Window(windowExprs, partitionSpec, orderSpec, planLater(child)) :: Nil
       case logical.Sample(lb, ub, withReplacement, seed, child) =>
         execution.Sample(lb, ub, withReplacement, seed, planLater(child)) :: Nil
       case logical.LocalRelation(output, data) =>
