Skip to content

Commit 3fd39c5

Browse files
committed
address comment
1 parent 9a0b788 commit 3fd39c5

File tree

2 files changed

+6
-8
lines changed

2 files changed

+6
-8
lines changed

sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -342,12 +342,10 @@ case class HashAggregateExec(
342342
// can't bind the `mergeExpressions` with the output of the partial aggregate, as they use
343343
// the `inputAggBufferAttributes` of the original `DeclarativeAggregate` before copy. Instead,
344344
// we shall use `inputAggBufferAttributes` after copy to match the new `mergeExpressions`.
345-
val aggAttrs = aggregateExpressions.map(_.aggregateFunction)
345+
val aggAttrs = aggregateExpressions
346+
.filter(a => a.mode == Final || !a.isDistinct).map(_.aggregateFunction)
346347
.flatMap(_.inputAggBufferAttributes)
347-
val distinctAttrs = child.output.filterNot(
348-
a => (groupingAttributes ++ aggAttrs).exists(_.name == a.name))
349-
// the order is consistent with `AggUtils.planAggregateWithOneDistinct`
350-
groupingAttributes ++ distinctAttrs ++ aggAttrs
348+
child.output.dropRight(aggAttrs.length) ++ aggAttrs
351349
} else {
352350
child.output
353351
}
@@ -870,9 +868,9 @@ case class HashAggregateExec(
870868
private def doConsumeWithKeys(ctx: CodegenContext, input: Seq[ExprCode]): String = {
871869
// create grouping key
872870
val unsafeRowKeyCode = GenerateUnsafeProjection.createCode(
873-
ctx, bindReferences[Expression](groupingExpressions, inputAttributes))
871+
ctx, bindReferences[Expression](groupingExpressions, child.output))
874872
val fastRowKeys = ctx.generateExpressions(
875-
bindReferences[Expression](groupingExpressions, inputAttributes))
873+
bindReferences[Expression](groupingExpressions, child.output))
876874
val unsafeRowKeys = unsafeRowKeyCode.value
877875
val unsafeRowKeyHash = ctx.freshName("unsafeRowKeyHash")
878876
val unsafeRowBuffer = ctx.freshName("unsafeRowAggBuffer")

sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -975,7 +975,7 @@ class DataFrameAggregateSuite extends QueryTest
975975
}
976976

977977
Seq(true, false).foreach { value =>
978-
test(s"SPARK-31620: agg with subquery (codegen = $value)") {
978+
test(s"SPARK-31620: agg with subquery (whole-stage-codegen = $value)") {
979979
withSQLConf(SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> value.toString) {
980980
withTempView("t1", "t2") {
981981
sql("create temporary view t1 as select * from values (1, 2) as t1(a, b)")

0 commit comments

Comments
 (0)