@@ -2457,11 +2457,13 @@ class SparkConnectPlanner(
2457
2457
input
2458
2458
}
2459
2459
2460
- val groupingExpressionsWithOrdinals = rel.getGroupingExpressionsList.asScala.toSeq
2461
- .map(transformGroupingExpressionAndReplaceOrdinals)
2460
+ val groupingExpressions =
2461
+ rel.getGroupingExpressionsList.asScala.toSeq.map(transformExpression)
2462
+ val groupingExpressionsWithOrdinals =
2463
+ groupingExpressions.map(replaceOrdinalsInGroupingExpressions)
2462
2464
val aggExprs = rel.getAggregateExpressionsList.asScala.toSeq
2463
2465
.map(expr => transformExpressionWithTypedReduceExpression(expr, logicalPlan))
2464
- val aliasedAgg = (groupingExpressionsWithOrdinals ++ aggExprs).map(toNamedExpression)
2466
+ val aliasedAgg = (groupingExpressions ++ aggExprs).map(toNamedExpression)
2465
2467
2466
2468
rel.getGroupType match {
2467
2469
case proto.Aggregate .GroupType .GROUP_TYPE_GROUPBY =>
@@ -2506,7 +2508,10 @@ class SparkConnectPlanner(
2506
2508
val groupingSetsExpressionsWithOrdinals =
2507
2509
rel.getGroupingSetsList.asScala.toSeq.map { getGroupingSets =>
2508
2510
getGroupingSets.getGroupingSetList.asScala.toSeq
2509
- .map(transformGroupingExpressionAndReplaceOrdinals)
2511
+ .map(groupingExpressions => {
2512
+ val transformedGroupingExpression = transformExpression(groupingExpressions)
2513
+ replaceOrdinalsInGroupingExpressions(transformedGroupingExpression)
2514
+ })
2510
2515
}
2511
2516
logical.Aggregate (
2512
2517
groupingExpressions = Seq (
@@ -2521,18 +2526,15 @@ class SparkConnectPlanner(
2521
2526
}
2522
2527
2523
2528
/**
2524
- * Transforms an input protobuf grouping expression into the Catalyst expression and converts
2525
- * top-level integer [[ Literal ]]s to [[ UnresolvedOrdinal ]]s, if `groupByOrdinal` is enabled.
2529
+ * Replaces top-level integer [[ Literal ]]s to [[ UnresolvedOrdinal ]]s, if `groupByOrdinal` is
2530
+ * enabled.
2526
2531
*/
2527
- private def transformGroupingExpressionAndReplaceOrdinals (
2528
- groupingExpression : proto.Expression ) = {
2529
- val transformedGroupingExpression = transformExpression(groupingExpression)
2532
+ private def replaceOrdinalsInGroupingExpressions (groupingExpression : Expression ) =
2530
2533
if (session.sessionState.conf.groupByOrdinal) {
2531
- replaceIntegerLiteralWithOrdinal(transformedGroupingExpression )
2534
+ replaceIntegerLiteralWithOrdinal(groupingExpression )
2532
2535
} else {
2533
- transformedGroupingExpression
2536
+ groupingExpression
2534
2537
}
2535
- }
2536
2538
2537
2539
@ deprecated(" TypedReduce is now implemented using a normal UDAF aggregator." , " 4.0.0" )
2538
2540
private def transformTypedReduceExpression (
0 commit comments