Commit 3d49bd4
[SPARK-24935][SQL][FOLLOWUP] support INIT -> UPDATE -> MERGE -> FINISH in Hive UDAF adapter
## What changes were proposed in this pull request?

This is a followup of #24144. #24144 missed one case: when a hash aggregate falls back to a sort aggregate, the life cycle of the UDAF is INIT -> UPDATE -> MERGE -> FINISH, and not all Hive UDAFs can support it. A Hive UDAF knows the aggregation mode when creating the aggregation buffer, so it can create different buffers for different inputs: the original data or the aggregation buffer. See an example in the [sketches library](https://github.com/DataSketches/sketches-hive/blob/7f9e76e9e03807277146291beb2c7bec40e8672b/src/main/java/com/yahoo/sketches/hive/cpc/DataToSketchUDAF.java#L107). The buffer created for UPDATE may not support MERGE.

This PR updates the Hive UDAF adapter in Spark to support INIT -> UPDATE -> MERGE -> FINISH by turning it into INIT -> UPDATE -> FINISH plus INIT -> MERGE -> FINISH.

## How was this patch tested?

A new test case.

Closes #24459 from cloud-fan/hive-udaf.

Authored-by: Wenchen Fan <wenchen@databricks.com>
Signed-off-by: Wenchen Fan <wenchen@databricks.com>
(cherry picked from commit 7432e7d)
Signed-off-by: Wenchen Fan <wenchen@databricks.com>
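For context, the mode-dependent buffer pattern that breaks UPDATE -> MERGE mixing looks roughly like the fragment below. This is a hypothetical Scala sketch modeled on the linked DataSketches UDAF, not code from this patch; `ModeSensitiveEvaluator`, `newUpdateBuffer`, and `newMergeBuffer` are illustrative names. The point is that the buffer implementation is fixed by the mode passed to `init()`, so a buffer created under PARTIAL1 (raw input) cannot later be fed to `merge()`.

```scala
import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator
import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator.{AggregationBuffer, Mode}
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector

// Hypothetical fragment: the evaluator remembers the mode from init() and
// chooses a buffer implementation accordingly, so iterate() and merge()
// each work with only one of the two buffer kinds.
abstract class ModeSensitiveEvaluator extends GenericUDAFEvaluator {
  private var currentMode: Mode = _

  override def init(m: Mode, parameters: Array[ObjectInspector]): ObjectInspector = {
    currentMode = m // Hive calls init() once, before any buffer is created
    super.init(m, parameters)
  }

  override def getNewAggregationBuffer(): AggregationBuffer = currentMode match {
    case Mode.PARTIAL1 | Mode.COMPLETE => newUpdateBuffer() // consumes raw rows
    case _                             => newMergeBuffer()  // PARTIAL2/FINAL: consumes partial results
  }

  // Buffer factories are left abstract in this sketch.
  protected def newUpdateBuffer(): AggregationBuffer
  protected def newMergeBuffer(): AggregationBuffer
}
```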
1 parent ba9e12d · commit 3d49bd4

2 files changed: +64 −28 lines changed


sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala

39 additions, 15 deletions

```diff
@@ -303,6 +303,13 @@ private[hive] case class HiveGenericUDTF(
  * - `wrap()`/`wrapperFor()`: from 3 to 1
  * - `unwrap()`/`unwrapperFor()`: from 1 to 3
  * - `GenericUDAFEvaluator.terminatePartial()`: from 2 to 3
+ *
+ * Note that, Hive UDAF is initialized with aggregate mode, and some specific Hive UDAFs can't
+ * mix UPDATE and MERGE actions during its life cycle. However, Spark may do UPDATE on a UDAF and
+ * then do MERGE, in case of hash aggregate falling back to sort aggregate. To work around this
+ * issue, we track the ability to do MERGE in the Hive UDAF aggregate buffer. If Spark does
+ * UPDATE then MERGE, we can detect it and re-create the aggregate buffer with a different
+ * aggregate mode.
  */
 private[hive] case class HiveUDAFFunction(
     name: String,
@@ -311,7 +318,7 @@ private[hive] case class HiveUDAFFunction(
     isUDAFBridgeRequired: Boolean = false,
     mutableAggBufferOffset: Int = 0,
     inputAggBufferOffset: Int = 0)
-  extends TypedImperativeAggregate[GenericUDAFEvaluator.AggregationBuffer]
+  extends TypedImperativeAggregate[HiveUDAFBuffer]
   with HiveInspectors
   with UserDefinedExpression {

@@ -397,55 +404,70 @@ private[hive] case class HiveUDAFFunction(
   // aggregate buffer. However, the Spark UDAF framework does not expose this information when
   // creating the buffer. Here we return null, and create the buffer in `update` and `merge`
   // on demand, so that we can know what input we are dealing with.
-  override def createAggregationBuffer(): AggregationBuffer = null
+  override def createAggregationBuffer(): HiveUDAFBuffer = null

   @transient
   private lazy val inputProjection = UnsafeProjection.create(children)

-  override def update(buffer: AggregationBuffer, input: InternalRow): AggregationBuffer = {
+  override def update(buffer: HiveUDAFBuffer, input: InternalRow): HiveUDAFBuffer = {
     // The input is original data, we create buffer with the partial1 evaluator.
     val nonNullBuffer = if (buffer == null) {
-      partial1HiveEvaluator.evaluator.getNewAggregationBuffer
+      HiveUDAFBuffer(partial1HiveEvaluator.evaluator.getNewAggregationBuffer, false)
     } else {
       buffer
     }

+    assert(!nonNullBuffer.canDoMerge, "can not call `merge` then `update` on a Hive UDAF.")
+
     partial1HiveEvaluator.evaluator.iterate(
-      nonNullBuffer, wrap(inputProjection(input), inputWrappers, cached, inputDataTypes))
+      nonNullBuffer.buf, wrap(inputProjection(input), inputWrappers, cached, inputDataTypes))
     nonNullBuffer
   }

-  override def merge(buffer: AggregationBuffer, input: AggregationBuffer): AggregationBuffer = {
+  override def merge(buffer: HiveUDAFBuffer, input: HiveUDAFBuffer): HiveUDAFBuffer = {
     // The input is aggregate buffer, we create buffer with the final evaluator.
     val nonNullBuffer = if (buffer == null) {
-      finalHiveEvaluator.evaluator.getNewAggregationBuffer
+      HiveUDAFBuffer(finalHiveEvaluator.evaluator.getNewAggregationBuffer, true)
     } else {
       buffer
     }

+    // It's possible that we've called `update` of this Hive UDAF, and some specific Hive UDAF
+    // implementation can't mix the `update` and `merge` calls during its life cycle. To work
+    // around it, here we create a fresh buffer with final evaluator, and merge the existing buffer
+    // to it, and replace the existing buffer with it.
+    val mergeableBuf = if (!nonNullBuffer.canDoMerge) {
+      val newBuf = finalHiveEvaluator.evaluator.getNewAggregationBuffer
+      finalHiveEvaluator.evaluator.merge(
+        newBuf, partial1HiveEvaluator.evaluator.terminatePartial(nonNullBuffer.buf))
+      HiveUDAFBuffer(newBuf, true)
+    } else {
+      nonNullBuffer
+    }
+
     // The 2nd argument of the Hive `GenericUDAFEvaluator.merge()` method is an input aggregation
     // buffer in the 3rd format mentioned in the ScalaDoc of this class. Originally, Hive converts
     // this `AggregationBuffer`s into this format before shuffling partial aggregation results, and
     // calls `GenericUDAFEvaluator.terminatePartial()` to do the conversion.
     finalHiveEvaluator.evaluator.merge(
-      nonNullBuffer, partial1HiveEvaluator.evaluator.terminatePartial(input))
-    nonNullBuffer
+      mergeableBuf.buf, partial1HiveEvaluator.evaluator.terminatePartial(input.buf))
+    mergeableBuf
   }

-  override def eval(buffer: AggregationBuffer): Any = {
-    resultUnwrapper(finalHiveEvaluator.evaluator.terminate(buffer))
+  override def eval(buffer: HiveUDAFBuffer): Any = {
+    resultUnwrapper(finalHiveEvaluator.evaluator.terminate(buffer.buf))
   }

-  override def serialize(buffer: AggregationBuffer): Array[Byte] = {
+  override def serialize(buffer: HiveUDAFBuffer): Array[Byte] = {
     // Serializes an `AggregationBuffer` that holds partial aggregation results so that we can
     // shuffle it for global aggregation later.
-    aggBufferSerDe.serialize(buffer)
+    aggBufferSerDe.serialize(buffer.buf)
   }

-  override def deserialize(bytes: Array[Byte]): AggregationBuffer = {
+  override def deserialize(bytes: Array[Byte]): HiveUDAFBuffer = {
     // Deserializes an `AggregationBuffer` from the shuffled partial aggregation phase to prepare
     // for global aggregation by merging multiple partial aggregation results within a single group.
-    aggBufferSerDe.deserialize(bytes)
+    HiveUDAFBuffer(aggBufferSerDe.deserialize(bytes), false)
   }

   // Helper class used to de/serialize Hive UDAF `AggregationBuffer` objects
@@ -493,3 +515,5 @@ private[hive] case class HiveUDAFFunction(
     }
   }
 }
+
+case class HiveUDAFBuffer(buf: AggregationBuffer, canDoMerge: Boolean)
```
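The net effect of the new `merge` above: a buffer that has only seen `update` is first terminated into a partial result and folded into a fresh merge-capable buffer, so each Hive buffer only ever sees one of the two modes. Below is a dependency-free sketch of that control flow, with hypothetical names and a plain running sum standing in for a real Hive aggregation buffer:

```scala
// A minimal model of the lifecycle rewrite: UPDATE -> MERGE becomes
// UPDATE -> FINISH on the old buffer plus INIT -> MERGE on a fresh one.
object LifecycleModel {
  final case class Buf(sum: Long, canDoMerge: Boolean)

  def update(buffer: Buf, input: Long): Buf = {
    assert(!buffer.canDoMerge, "can not call `merge` then `update`")
    buffer.copy(sum = buffer.sum + input)
  }

  // Stand-in for GenericUDAFEvaluator.terminatePartial().
  def terminatePartial(buffer: Buf): Long = buffer.sum

  def merge(buffer: Buf, partial: Long): Buf = {
    val mergeable =
      if (buffer.canDoMerge) buffer
      else Buf(terminatePartial(buffer), canDoMerge = true) // re-init for MERGE
    mergeable.copy(sum = mergeable.sum + partial)
  }

  def main(args: Array[String]): Unit = {
    val updated = update(Buf(0L, canDoMerge = false), 5L) // UPDATE on raw input
    val merged = merge(updated, 7L)                       // MERGE after UPDATE, rewritten
    println(merged)                                       // Buf(12,true)
  }
}
```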

sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDAFSuite.scala

25 additions, 13 deletions

```diff
@@ -28,10 +28,10 @@ import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectIn
 import org.apache.hadoop.hive.serde2.typeinfo.TypeInfo
 import test.org.apache.spark.sql.MyDoubleAvg

-import org.apache.spark.SparkException
 import org.apache.spark.sql.{AnalysisException, QueryTest, Row}
 import org.apache.spark.sql.execution.aggregate.ObjectHashAggregateExec
 import org.apache.spark.sql.hive.test.TestHiveSingleton
+import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.test.SQLTestUtils

 class HiveUDAFSuite extends QueryTest with TestHiveSingleton with SQLTestUtils {
@@ -93,21 +93,33 @@ class HiveUDAFSuite extends QueryTest with TestHiveSingleton with SQLTestUtils {
     ))
   }

-  test("customized Hive UDAF with two aggregation buffers") {
-    val df = sql("SELECT key % 2, mock2(value) FROM t GROUP BY key % 2")
+  test("SPARK-24935: customized Hive UDAF with two aggregation buffers") {
+    withTempView("v") {
+      spark.range(100).createTempView("v")
+      val df = sql("SELECT id % 2, mock2(id) FROM v GROUP BY id % 2")

-    val aggs = df.queryExecution.executedPlan.collect {
-      case agg: ObjectHashAggregateExec => agg
-    }
+      val aggs = df.queryExecution.executedPlan.collect {
+        case agg: ObjectHashAggregateExec => agg
+      }

-    // There should be two aggregate operators, one for partial aggregation, and the other for
-    // global aggregation.
-    assert(aggs.length == 2)
+      // There should be two aggregate operators, one for partial aggregation, and the other for
+      // global aggregation.
+      assert(aggs.length == 2)

-    checkAnswer(df, Seq(
-      Row(0, Row(1, 1)),
-      Row(1, Row(1, 1))
-    ))
+      withSQLConf(SQLConf.OBJECT_AGG_SORT_BASED_FALLBACK_THRESHOLD.key -> "1") {
+        checkAnswer(df, Seq(
+          Row(0, Row(50, 0)),
+          Row(1, Row(50, 0))
+        ))
+      }
+
+      withSQLConf(SQLConf.OBJECT_AGG_SORT_BASED_FALLBACK_THRESHOLD.key -> "100") {
+        checkAnswer(df, Seq(
+          Row(0, Row(50, 0)),
+          Row(1, Row(50, 0))
+        ))
+      }
+    }
   }

   test("call JAVA UDAF") {
```
