
Commit 4b3a653

mihailotim-db authored and cloud-fan committed
[SPARK-51820][SQL][FOLLOWUP][CONNECT] Don't add UnresolvedOrdinal when appending grouping to aggregate expressions
### What changes were proposed in this pull request?

Don't add `UnresolvedOrdinal` when appending grouping to aggregate expressions in Spark Connect.

### Why are the changes needed?

The change fixes a regression caused by #50606, where an `UnresolvedOrdinal` would end up in the aggregate expression list and propagate all the way to `CheckAnalysis`, which would throw an error.

### Does this PR introduce _any_ user-facing change?

No.

### How was this patch tested?

Added a test case for the correct behavior.

### Was this patch authored or co-authored using generative AI tooling?

No.

Closes #50958 from mihailotim-db/mihailotim-db/fix_unresolved_ordinal.

Authored-by: Mihailo Timotic <mihailo.timotic@databricks.com>
Signed-off-by: Wenchen Fan <wenchen@databricks.com>
1 parent 9ce5914 commit 4b3a653
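
To make the intent of the planner change below easier to see at a glance, here is a minimal, self-contained sketch of the idea under toy types: ordinal substitution is applied only to the list used as `groupingExpressions`, while the aggregate output list reuses the untouched expressions. Everything in this sketch (`Expr`, `IntLiteral`, `Ordinal`, `Column`, `replaceOrdinals`) is invented for illustration and is not the actual Catalyst or Connect API.

```scala
// Toy sketch of the fix, not the Spark Connect planner code itself.
object GroupingOrdinalSketch {
  sealed trait Expr
  final case class IntLiteral(value: Int) extends Expr
  final case class Ordinal(index: Int) extends Expr
  final case class Column(name: String) extends Expr

  // Counterpart of replaceOrdinalsInGroupingExpressions: only top-level integer
  // literals become ordinals, and only when group-by-ordinal is enabled.
  def replaceOrdinals(e: Expr, groupByOrdinal: Boolean): Expr = e match {
    case IntLiteral(i) if groupByOrdinal => Ordinal(i)
    case other => other
  }

  def main(args: Array[String]): Unit = {
    val grouping: Seq[Expr] = Seq(IntLiteral(1), Column("dept"))
    val aggExprs: Seq[Expr] = Seq(Column("sum_salary"))

    // Goes into Aggregate.groupingExpressions: ordinals are expected here.
    val groupingWithOrdinals = grouping.map(e => replaceOrdinals(e, groupByOrdinal = true))

    // Goes into Aggregate.aggregateExpressions: before the fix this list was built from
    // the ordinal-substituted expressions and tripped CheckAnalysis; now it reuses the
    // untouched grouping expressions.
    val aliasedAgg = grouping ++ aggExprs

    println(groupingWithOrdinals) // List(Ordinal(1), Column(dept))
    println(aliasedAgg)           // List(IntLiteral(1), Column(dept), Column(sum_salary))
  }
}
```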

File tree

2 files changed: +46, -13 lines


sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala

Lines changed: 14 additions & 12 deletions
```diff
@@ -2457,11 +2457,13 @@ class SparkConnectPlanner(
       input
     }
 
-    val groupingExpressionsWithOrdinals = rel.getGroupingExpressionsList.asScala.toSeq
-      .map(transformGroupingExpressionAndReplaceOrdinals)
+    val groupingExpressions =
+      rel.getGroupingExpressionsList.asScala.toSeq.map(transformExpression)
+    val groupingExpressionsWithOrdinals =
+      groupingExpressions.map(replaceOrdinalsInGroupingExpressions)
     val aggExprs = rel.getAggregateExpressionsList.asScala.toSeq
       .map(expr => transformExpressionWithTypedReduceExpression(expr, logicalPlan))
-    val aliasedAgg = (groupingExpressionsWithOrdinals ++ aggExprs).map(toNamedExpression)
+    val aliasedAgg = (groupingExpressions ++ aggExprs).map(toNamedExpression)
 
     rel.getGroupType match {
       case proto.Aggregate.GroupType.GROUP_TYPE_GROUPBY =>
@@ -2506,7 +2508,10 @@
         val groupingSetsExpressionsWithOrdinals =
           rel.getGroupingSetsList.asScala.toSeq.map { getGroupingSets =>
             getGroupingSets.getGroupingSetList.asScala.toSeq
-              .map(transformGroupingExpressionAndReplaceOrdinals)
+              .map(groupingExpressions => {
+                val transformedGroupingExpression = transformExpression(groupingExpressions)
+                replaceOrdinalsInGroupingExpressions(transformedGroupingExpression)
+              })
           }
         logical.Aggregate(
           groupingExpressions = Seq(
@@ -2521,18 +2526,15 @@
     }
 
   /**
-   * Transforms an input protobuf grouping expression into the Catalyst expression and converts
-   * top-level integer [[Literal]]s to [[UnresolvedOrdinal]]s, if `groupByOrdinal` is enabled.
+   * Replaces top-level integer [[Literal]]s to [[UnresolvedOrdinal]]s, if `groupByOrdinal` is
+   * enabled.
    */
-  private def transformGroupingExpressionAndReplaceOrdinals(
-      groupingExpression: proto.Expression) = {
-    val transformedGroupingExpression = transformExpression(groupingExpression)
+  private def replaceOrdinalsInGroupingExpressions(groupingExpression: Expression) =
     if (session.sessionState.conf.groupByOrdinal) {
-      replaceIntegerLiteralWithOrdinal(transformedGroupingExpression)
+      replaceIntegerLiteralWithOrdinal(groupingExpression)
     } else {
-      transformedGroupingExpression
+      groupingExpression
     }
-  }
 
   @deprecated("TypedReduce is now implemented using a normal UDAF aggregator.", "4.0.0")
   private def transformTypedReduceExpression(
```
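
The `groupByOrdinal` check kept in the helper above corresponds to the SQL conf `spark.sql.groupByOrdinal`, which controls whether integer literals in a GROUP BY clause are treated as positions in the select list. As a rough, hypothetical illustration of the feature that check guards (assuming an existing session named `spark`; table and column names are invented):

```scala
// Hypothetical query over a made-up table. With spark.sql.groupByOrdinal=true (the
// default), the literal 1 in GROUP BY refers to the first item of the select list,
// which is what UnresolvedOrdinal represents during analysis.
spark.conf.set("spark.sql.groupByOrdinal", "true")
spark.sql("SELECT dept, sum(salary) FROM employees GROUP BY 1").show()
```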

sql/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectPlannerSuite.scala

Lines changed: 32 additions & 1 deletion
```diff
@@ -29,7 +29,8 @@ import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.analysis.{UnresolvedAlias, UnresolvedFunction, UnresolvedRelation}
 import org.apache.spark.sql.catalyst.expressions.{AttributeReference, UnsafeProjection}
 import org.apache.spark.sql.catalyst.plans.logical
-import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
+import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LogicalPlan}
+import org.apache.spark.sql.catalyst.trees.TreePattern
 import org.apache.spark.sql.catalyst.types.DataTypeUtils
 import org.apache.spark.sql.classic.Dataset
 import org.apache.spark.sql.connect.SparkConnectTestUtils
@@ -922,4 +923,34 @@ class SparkConnectPlannerSuite extends SparkFunSuite with SparkConnectPlanTest {
     assert(fn3.nameParts.head == "abcde")
     assert(fn3.isInternal)
   }
+
+  test("SPARK-51820 aggregate list should not contain UnresolvedOrdinal") {
+    val ordinal = proto.Expression
+      .newBuilder()
+      .setLiteral(proto.Expression.Literal.newBuilder().setInteger(1).build())
+      .build()
+
+    val sum =
+      proto.Expression
+        .newBuilder()
+        .setUnresolvedFunction(
+          proto.Expression.UnresolvedFunction
+            .newBuilder()
+            .setFunctionName("sum")
+            .addArguments(ordinal))
+        .build()
+
+    val aggregate = proto.Aggregate.newBuilder
+      .setInput(readRel)
+      .addAggregateExpressions(sum)
+      .addGroupingExpressions(ordinal)
+      .setGroupType(proto.Aggregate.GroupType.GROUP_TYPE_ROLLUP)
+      .build()
+
+    val plan =
+      transform(proto.Relation.newBuilder.setAggregate(aggregate).build()).asInstanceOf[Aggregate]
+
+    assert(plan.aggregateExpressions.forall(aggregateExpression =>
+      !aggregateExpression.containsPattern(TreePattern.UNRESOLVED_ORDINAL)))
+  }
 }
```
