@@ -342,12 +342,10 @@ case class HashAggregateExec(
342
342
// can't bind the `mergeExpressions` with the output of the partial aggregate, as they use
343
343
// the `inputAggBufferAttributes` of the original `DeclarativeAggregate` before copy. Instead,
344
344
// we shall use `inputAggBufferAttributes` after copy to match the new `mergeExpressions`.
345
- val aggAttrs = aggregateExpressions.map(_.aggregateFunction)
345
+ val aggAttrs = aggregateExpressions
346
+ .filter(a => a.mode == Final || ! a.isDistinct).map(_.aggregateFunction)
346
347
.flatMap(_.inputAggBufferAttributes)
347
- val distinctAttrs = child.output.filterNot(
348
- a => (groupingAttributes ++ aggAttrs).exists(_.name == a.name))
349
- // the order is consistent with `AggUtils.planAggregateWithOneDistinct`
350
- groupingAttributes ++ distinctAttrs ++ aggAttrs
348
+ child.output.dropRight(aggAttrs.length) ++ aggAttrs
351
349
} else {
352
350
child.output
353
351
}
@@ -870,9 +868,9 @@ case class HashAggregateExec(
870
868
private def doConsumeWithKeys (ctx : CodegenContext , input : Seq [ExprCode ]): String = {
871
869
// create grouping key
872
870
val unsafeRowKeyCode = GenerateUnsafeProjection .createCode(
873
- ctx, bindReferences[Expression ](groupingExpressions, inputAttributes ))
871
+ ctx, bindReferences[Expression ](groupingExpressions, child.output ))
874
872
val fastRowKeys = ctx.generateExpressions(
875
- bindReferences[Expression ](groupingExpressions, inputAttributes ))
873
+ bindReferences[Expression ](groupingExpressions, child.output ))
876
874
val unsafeRowKeys = unsafeRowKeyCode.value
877
875
val unsafeRowKeyHash = ctx.freshName(" unsafeRowKeyHash" )
878
876
val unsafeRowBuffer = ctx.freshName(" unsafeRowAggBuffer" )
0 commit comments