From 68d3388001615841685e9942c4220d7904f33665 Mon Sep 17 00:00:00 2001 From: gengjiaan Date: Thu, 12 Nov 2020 13:01:48 +0800 Subject: [PATCH] Optimize code --- .../spark/sql/catalyst/optimizer/Optimizer.scala | 3 ++- .../optimizer/OptimizeWindowFunctionsSuite.scala | 13 ++++++++++++- 2 files changed, 14 insertions(+), 2 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index 57ddd392e180a..87eb0be77fcee 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -813,7 +813,8 @@ object CollapseRepartition extends Rule[LogicalPlan] { object OptimizeWindowFunctions extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = plan resolveExpressions { case we @ WindowExpression(AggregateExpression(first: First, _, _, _, _), spec) - if !spec.orderSpec.isEmpty => + if spec.orderSpec.nonEmpty && + spec.frameSpecification.asInstanceOf[SpecifiedWindowFrame].frameType == RowFrame => we.copy(windowFunction = NthValue(first.child, Literal(1), first.ignoreNulls)) } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeWindowFunctionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeWindowFunctionsSuite.scala index c89208dce45d6..dfe1d47bcba06 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeWindowFunctionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeWindowFunctionsSuite.scala @@ -36,7 +36,7 @@ class OptimizeWindowFunctionsSuite extends PlanTest { val b = testRelation.output(1) val c = testRelation.output(2) - test("replace first(col) by nth_value(col, 1) if the window frame is ordered") { + test("replace first(col) by nth_value(col, 1)") { val inputPlan = testRelation.select( WindowExpression( First(a, false).toAggregateExpression(), @@ -52,6 +52,17 @@ class OptimizeWindowFunctionsSuite extends PlanTest { assert(optimized == correctAnswer) } + test("can't replace first(col) by nth_value(col, 1) if the window frame type is row") { + val inputPlan = testRelation.select( + WindowExpression( + First(a, false).toAggregateExpression(), + WindowSpecDefinition(b :: Nil, c.asc :: Nil, + SpecifiedWindowFrame(RangeFrame, UnboundedPreceding, CurrentRow)))) + + val optimized = Optimize.execute(inputPlan) + assert(optimized == inputPlan) + } + test("can't replace first(col) by nth_value(col, 1) if the window frame isn't ordered") { val inputPlan = testRelation.select( WindowExpression(