
Commit 85fc2f2

pgandhi and cloud-fan authored and committed
[SPARK-24935][SQL][2.3] fix Hive UDAF with two aggregation buffers
## What changes were proposed in this pull request?

Backport #24144 and #24459 to 2.3.

## How was this patch tested?

Existing tests.

Closes #24539 from cloud-fan/backport.

Lead-authored-by: pgandhi <pgandhi@verizonmedia.com>
Co-authored-by: Wenchen Fan <wenchen@databricks.com>
Signed-off-by: Dongjoon Hyun <dhyun@apple.com>
1 parent 52daf49 · commit 85fc2f2

2 files changed (+192, -32 lines)

sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala

Lines changed: 74 additions & 32 deletions
@@ -303,6 +303,13 @@ private[hive] case class HiveGenericUDTF(
  * - `wrap()`/`wrapperFor()`: from 3 to 1
  * - `unwrap()`/`unwrapperFor()`: from 1 to 3
  * - `GenericUDAFEvaluator.terminatePartial()`: from 2 to 3
+ *
+ * Note that, Hive UDAF is initialized with aggregate mode, and some specific Hive UDAFs can't
+ * mix UPDATE and MERGE actions during its life cycle. However, Spark may do UPDATE on a UDAF and
+ * then do MERGE, in case of hash aggregate falling back to sort aggregate. To work around this
+ * issue, we track the ability to do MERGE in the Hive UDAF aggregate buffer. If Spark does
+ * UPDATE then MERGE, we can detect it and re-create the aggregate buffer with a different
+ * aggregate mode.
  */
 private[hive] case class HiveUDAFFunction(
     name: String,
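The ScalaDoc added above captures the core of the fix: some Hive UDAF evaluators are mode-specific, yet Spark's hash aggregate can switch from UPDATE to MERGE mid-stream when it falls back to sort-based aggregation. A minimal sketch of the tracking pattern (not Spark source; `Buf`, `freshFinalModeBuf`, and `fold` are hypothetical stand-ins for Hive's `AggregationBuffer`, a FINAL-mode buffer, and `GenericUDAFEvaluator.merge()`):

```scala
// Minimal sketch of the canDoMerge tracking pattern (hypothetical names).
case class Buf(state: AnyRef, canDoMerge: Boolean)

def onMerge(buffer: Buf, freshFinalModeBuf: => AnyRef, fold: (AnyRef, AnyRef) => AnyRef): Buf =
  if (buffer.canDoMerge) {
    buffer // already merge-capable, keep using it
  } else {
    // The buffer has only seen UPDATE so far: migrate its state into a fresh
    // merge-capable buffer instead of mixing modes on the old one.
    Buf(fold(freshFinalModeBuf, buffer.state), canDoMerge = true)
  }
```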
@@ -311,7 +318,7 @@ private[hive] case class HiveUDAFFunction(
     isUDAFBridgeRequired: Boolean = false,
     mutableAggBufferOffset: Int = 0,
     inputAggBufferOffset: Int = 0)
-  extends TypedImperativeAggregate[GenericUDAFEvaluator.AggregationBuffer]
+  extends TypedImperativeAggregate[HiveUDAFBuffer]
   with HiveInspectors
   with UserDefinedExpression {
 
@@ -352,29 +359,21 @@ private[hive] case class HiveUDAFFunction(
     HiveEvaluator(evaluator, evaluator.init(GenericUDAFEvaluator.Mode.PARTIAL1, inputInspectors))
   }
 
-  // The UDAF evaluator used to merge partial aggregation results.
+  // The UDAF evaluator used to consume partial aggregation results and produce final results.
+  // Hive `ObjectInspector` used to inspect final results.
   @transient
-  private lazy val partial2ModeEvaluator = {
+  private lazy val finalHiveEvaluator = {
     val evaluator = newEvaluator()
-    evaluator.init(GenericUDAFEvaluator.Mode.PARTIAL2, Array(partial1HiveEvaluator.objectInspector))
-    evaluator
+    HiveEvaluator(
+      evaluator,
+      evaluator.init(GenericUDAFEvaluator.Mode.FINAL, Array(partial1HiveEvaluator.objectInspector)))
   }
 
   // Spark SQL data type of partial aggregation results
   @transient
   private lazy val partialResultDataType =
     inspectorToDataType(partial1HiveEvaluator.objectInspector)
 
-  // The UDAF evaluator used to compute the final result from a partial aggregation result objects.
-  // Hive `ObjectInspector` used to inspect the final aggregation result object.
-  @transient
-  private lazy val finalHiveEvaluator = {
-    val evaluator = newEvaluator()
-    HiveEvaluator(
-      evaluator,
-      evaluator.init(GenericUDAFEvaluator.Mode.FINAL, Array(partial1HiveEvaluator.objectInspector)))
-  }
-
   // Wrapper functions used to wrap Spark SQL input arguments into Hive specific format.
   @transient
   private lazy val inputWrappers = children.map(x => wrapperFor(toInspector(x), x.dataType)).toArray
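For background on why a single FINAL-mode evaluator can replace the PARTIAL2-mode one: a Hive `GenericUDAFEvaluator` is initialized once with a mode, and `init()` takes the inspectors for what that mode consumes and returns the inspector for what it produces. PARTIAL2 and FINAL both consume the partial-result format, so the FINAL evaluator can serve both merging and termination. A sketch of that handshake, assuming an evaluator factory like the `newEvaluator()` in this file:

```scala
import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator
import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator.Mode
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector

// Sketch: initialize a FINAL-mode evaluator from the partial-result inspector.
def initFinal(
    newEvaluator: () => GenericUDAFEvaluator,
    partialResultOI: ObjectInspector): (GenericUDAFEvaluator, ObjectInspector) = {
  val evaluator = newEvaluator()
  // For FINAL mode, the input inspectors describe partial results and the
  // returned inspector describes the final result.
  val finalResultOI = evaluator.init(Mode.FINAL, Array(partialResultOI))
  (evaluator, finalResultOI)
}
```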
@@ -401,41 +400,74 @@ private[hive] case class HiveUDAFFunction(
     s"$name($distinct${children.map(_.sql).mkString(", ")})"
   }
 
-  override def createAggregationBuffer(): AggregationBuffer =
-    partial1HiveEvaluator.evaluator.getNewAggregationBuffer
+  // The hive UDAF may create different buffers to handle different inputs: original data or
+  // aggregate buffer. However, the Spark UDAF framework does not expose this information when
+  // creating the buffer. Here we return null, and create the buffer in `update` and `merge`
+  // on demand, so that we can know what input we are dealing with.
+  override def createAggregationBuffer(): HiveUDAFBuffer = null
 
   @transient
   private lazy val inputProjection = UnsafeProjection.create(children)
 
-  override def update(buffer: AggregationBuffer, input: InternalRow): AggregationBuffer = {
+  override def update(buffer: HiveUDAFBuffer, input: InternalRow): HiveUDAFBuffer = {
+    // The input is original data, we create buffer with the partial1 evaluator.
+    val nonNullBuffer = if (buffer == null) {
+      HiveUDAFBuffer(partial1HiveEvaluator.evaluator.getNewAggregationBuffer, false)
+    } else {
+      buffer
+    }
+
+    assert(!nonNullBuffer.canDoMerge, "can not call `merge` then `update` on a Hive UDAF.")
+
     partial1HiveEvaluator.evaluator.iterate(
-      buffer, wrap(inputProjection(input), inputWrappers, cached, inputDataTypes))
-    buffer
+      nonNullBuffer.buf, wrap(inputProjection(input), inputWrappers, cached, inputDataTypes))
+    nonNullBuffer
   }
 
-  override def merge(buffer: AggregationBuffer, input: AggregationBuffer): AggregationBuffer = {
+  override def merge(buffer: HiveUDAFBuffer, input: HiveUDAFBuffer): HiveUDAFBuffer = {
+    // The input is aggregate buffer, we create buffer with the final evaluator.
+    val nonNullBuffer = if (buffer == null) {
+      HiveUDAFBuffer(finalHiveEvaluator.evaluator.getNewAggregationBuffer, true)
+    } else {
+      buffer
+    }
+
+    // It's possible that we've called `update` of this Hive UDAF, and some specific Hive UDAF
+    // implementation can't mix the `update` and `merge` calls during its life cycle. To work
+    // around it, here we create a fresh buffer with final evaluator, and merge the existing buffer
+    // to it, and replace the existing buffer with it.
+    val mergeableBuf = if (!nonNullBuffer.canDoMerge) {
+      val newBuf = finalHiveEvaluator.evaluator.getNewAggregationBuffer
+      finalHiveEvaluator.evaluator.merge(
+        newBuf, partial1HiveEvaluator.evaluator.terminatePartial(nonNullBuffer.buf))
+      HiveUDAFBuffer(newBuf, true)
+    } else {
+      nonNullBuffer
+    }
+
     // The 2nd argument of the Hive `GenericUDAFEvaluator.merge()` method is an input aggregation
     // buffer in the 3rd format mentioned in the ScalaDoc of this class. Originally, Hive converts
     // this `AggregationBuffer`s into this format before shuffling partial aggregation results, and
     // calls `GenericUDAFEvaluator.terminatePartial()` to do the conversion.
-    partial2ModeEvaluator.merge(buffer, partial1HiveEvaluator.evaluator.terminatePartial(input))
-    buffer
+    finalHiveEvaluator.evaluator.merge(
+      mergeableBuf.buf, partial1HiveEvaluator.evaluator.terminatePartial(input.buf))
+    mergeableBuf
   }
 
-  override def eval(buffer: AggregationBuffer): Any = {
-    resultUnwrapper(finalHiveEvaluator.evaluator.terminate(buffer))
+  override def eval(buffer: HiveUDAFBuffer): Any = {
+    resultUnwrapper(finalHiveEvaluator.evaluator.terminate(buffer.buf))
   }
 
-  override def serialize(buffer: AggregationBuffer): Array[Byte] = {
+  override def serialize(buffer: HiveUDAFBuffer): Array[Byte] = {
     // Serializes an `AggregationBuffer` that holds partial aggregation results so that we can
     // shuffle it for global aggregation later.
-    aggBufferSerDe.serialize(buffer)
+    aggBufferSerDe.serialize(buffer.buf)
   }
 
-  override def deserialize(bytes: Array[Byte]): AggregationBuffer = {
+  override def deserialize(bytes: Array[Byte]): HiveUDAFBuffer = {
     // Deserializes an `AggregationBuffer` from the shuffled partial aggregation phase to prepare
     // for global aggregation by merging multiple partial aggregation results within a single group.
-    aggBufferSerDe.deserialize(bytes)
+    HiveUDAFBuffer(aggBufferSerDe.deserialize(bytes), false)
   }
 
   // Helper class used to de/serialize Hive UDAF `AggregationBuffer` objects
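To make the new `update`/`merge` logic concrete, this is roughly the call sequence that arises when the object hash aggregate falls back to sort-based aggregation (a schematic, not Spark source):

```scala
// Schematic lifecycle of one HiveUDAFBuffer during hash-to-sort fallback:
// createAggregationBuffer()  -> null            (buffer type not known yet)
// update(null, row)          -> PARTIAL1 buffer (canDoMerge = false)
// update(buffer, row) ...    -> keeps accumulating via iterate()
// --- hash aggregate falls back to sort-based aggregation ---
// merge(buffer, input)       -> canDoMerge is false, so the PARTIAL1 buffer is drained
//                               with terminatePartial() and folded into a fresh
//                               FINAL-mode buffer before merging proceeds
```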
@@ -450,11 +482,19 @@ private[hive] case class HiveUDAFFunction(
     private val mutableRow = new GenericInternalRow(1)
 
     def serialize(buffer: AggregationBuffer): Array[Byte] = {
+      // The buffer may be null if there is no input. It's unclear if the hive UDAF accepts null
+      // buffer, for safety we create an empty buffer here.
+      val nonNullBuffer = if (buffer == null) {
+        partial1HiveEvaluator.evaluator.getNewAggregationBuffer
+      } else {
+        buffer
+      }
+
       // `GenericUDAFEvaluator.terminatePartial()` converts an `AggregationBuffer` into an object
       // that can be inspected by the `ObjectInspector` returned by `GenericUDAFEvaluator.init()`.
       // Then we can unwrap it to a Spark SQL value.
       mutableRow.update(0, partialResultUnwrapper(
-        partial1HiveEvaluator.evaluator.terminatePartial(buffer)))
+        partial1HiveEvaluator.evaluator.terminatePartial(nonNullBuffer)))
       val unsafeRow = projection(mutableRow)
       val bytes = ByteBuffer.allocate(unsafeRow.getSizeInBytes)
       unsafeRow.writeTo(bytes)
@@ -466,12 +506,14 @@ private[hive] case class HiveUDAFFunction(
       // returned by `GenericUDAFEvaluator.terminatePartial()` back to an `AggregationBuffer`. The
       // workaround here is creating an initial `AggregationBuffer` first and then merge the
       // deserialized object into the buffer.
-      val buffer = partial2ModeEvaluator.getNewAggregationBuffer
+      val buffer = finalHiveEvaluator.evaluator.getNewAggregationBuffer
       val unsafeRow = new UnsafeRow(1)
       unsafeRow.pointTo(bytes, bytes.length)
       val partialResult = unsafeRow.get(0, partialResultDataType)
-      partial2ModeEvaluator.merge(buffer, partialResultWrapper(partialResult))
+      finalHiveEvaluator.evaluator.merge(buffer, partialResultWrapper(partialResult))
       buffer
     }
   }
 }
+
+case class HiveUDAFBuffer(buf: AggregationBuffer, canDoMerge: Boolean)
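The round trip implemented by `AggBufferSerDe` above, summarized (the arrows label steps in the code, not real helpers):

```scala
// serialize:   AggregationBuffer --terminatePartial()--> Hive struct
//              --partialResultUnwrapper--> Spark value --UnsafeProjection--> bytes
// deserialize: bytes --> UnsafeRow --> Spark value --partialResultWrapper--> Hive struct
//              --merge() into a fresh FINAL-mode buffer--> AggregationBuffer
```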

sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDAFSuite.scala

Lines changed: 118 additions & 0 deletions
@@ -31,6 +31,7 @@ import test.org.apache.spark.sql.MyDoubleAvg
 import org.apache.spark.sql.{AnalysisException, QueryTest, Row}
 import org.apache.spark.sql.execution.aggregate.ObjectHashAggregateExec
 import org.apache.spark.sql.hive.test.TestHiveSingleton
+import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.test.SQLTestUtils
 
 class HiveUDAFSuite extends QueryTest with TestHiveSingleton with SQLTestUtils {
@@ -39,6 +40,7 @@ class HiveUDAFSuite extends QueryTest with TestHiveSingleton with SQLTestUtils {
   protected override def beforeAll(): Unit = {
     sql(s"CREATE TEMPORARY FUNCTION mock AS '${classOf[MockUDAF].getName}'")
     sql(s"CREATE TEMPORARY FUNCTION hive_max AS '${classOf[GenericUDAFMax].getName}'")
+    sql(s"CREATE TEMPORARY FUNCTION mock2 AS '${classOf[MockUDAF2].getName}'")
 
     Seq(
       (0: Integer) -> "val_0",
@@ -91,6 +93,35 @@ class HiveUDAFSuite extends QueryTest with TestHiveSingleton with SQLTestUtils {
     ))
   }
 
+  test("SPARK-24935: customized Hive UDAF with two aggregation buffers") {
+    withTempView("v") {
+      spark.range(100).createTempView("v")
+      val df = sql("SELECT id % 2, mock2(id) FROM v GROUP BY id % 2")
+
+      val aggs = df.queryExecution.executedPlan.collect {
+        case agg: ObjectHashAggregateExec => agg
+      }
+
+      // There should be two aggregate operators, one for partial aggregation, and the other for
+      // global aggregation.
+      assert(aggs.length == 2)
+
+      withSQLConf(SQLConf.OBJECT_AGG_SORT_BASED_FALLBACK_THRESHOLD.key -> "1") {
+        checkAnswer(df, Seq(
+          Row(0, Row(50, 0)),
+          Row(1, Row(50, 0))
+        ))
+      }
+
+      withSQLConf(SQLConf.OBJECT_AGG_SORT_BASED_FALLBACK_THRESHOLD.key -> "100") {
+        checkAnswer(df, Seq(
+          Row(0, Row(50, 0)),
+          Row(1, Row(50, 0))
+        ))
+      }
+    }
+  }
+
   test("call JAVA UDAF") {
     withTempView("temp") {
       withUserDefinedFunction("myDoubleAvg" -> false) {
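The two `withSQLConf` blocks exercise both paths: a fallback threshold of 1 makes the object hash aggregate fall back to sort-based aggregation almost immediately, triggering the UPDATE-then-MERGE sequence this patch fixes, while 100 keeps the whole query in hash mode. Roughly the same experiment in a `spark-shell`, assuming `mock2` has been registered as in `beforeAll` (the string key is the one behind `SQLConf.OBJECT_AGG_SORT_BASED_FALLBACK_THRESHOLD`):

```scala
spark.conf.set("spark.sql.objectHashAggregate.sortBased.fallbackThreshold", "1")
spark.range(100).createOrReplaceTempView("v")
spark.sql("SELECT id % 2, mock2(id) FROM v GROUP BY id % 2").show()
// With the fix, both groups report (50, 0): 50 non-null inputs, 0 null inputs.
```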
@@ -126,12 +157,22 @@ class MockUDAF extends AbstractGenericUDAFResolver {
   override def getEvaluator(info: Array[TypeInfo]): GenericUDAFEvaluator = new MockUDAFEvaluator
 }
 
+class MockUDAF2 extends AbstractGenericUDAFResolver {
+  override def getEvaluator(info: Array[TypeInfo]): GenericUDAFEvaluator = new MockUDAFEvaluator2
+}
+
 class MockUDAFBuffer(var nonNullCount: Long, var nullCount: Long)
   extends GenericUDAFEvaluator.AbstractAggregationBuffer {
 
   override def estimate(): Int = JavaDataModel.PRIMITIVES2 * 2
 }
 
+class MockUDAFBuffer2(var nonNullCount: Long, var nullCount: Long)
+  extends GenericUDAFEvaluator.AbstractAggregationBuffer {
+
+  override def estimate(): Int = JavaDataModel.PRIMITIVES2 * 2
+}
+
 class MockUDAFEvaluator extends GenericUDAFEvaluator {
   private val nonNullCountOI = PrimitiveObjectInspectorFactory.javaLongObjectInspector
 
@@ -183,3 +224,80 @@ class MockUDAFEvaluator extends GenericUDAFEvaluator {
 
   override def terminate(agg: AggregationBuffer): AnyRef = terminatePartial(agg)
 }
+
+// Same as MockUDAFEvaluator but using two aggregation buffers, one for PARTIAL1 and the other
+// for PARTIAL2.
+class MockUDAFEvaluator2 extends GenericUDAFEvaluator {
+  private val nonNullCountOI = PrimitiveObjectInspectorFactory.javaLongObjectInspector
+
+  private val nullCountOI = PrimitiveObjectInspectorFactory.javaLongObjectInspector
+  private var aggMode: Mode = null
+
+  private val bufferOI = {
+    val fieldNames = Seq("nonNullCount", "nullCount").asJava
+    val fieldOIs = Seq(nonNullCountOI: ObjectInspector, nullCountOI: ObjectInspector).asJava
+    ObjectInspectorFactory.getStandardStructObjectInspector(fieldNames, fieldOIs)
+  }
+
+  private val nonNullCountField = bufferOI.getStructFieldRef("nonNullCount")
+
+  private val nullCountField = bufferOI.getStructFieldRef("nullCount")
+
+  override def getNewAggregationBuffer: AggregationBuffer = {
+    // These 2 modes consume original data.
+    if (aggMode == Mode.PARTIAL1 || aggMode == Mode.COMPLETE) {
+      new MockUDAFBuffer(0L, 0L)
+    } else {
+      new MockUDAFBuffer2(0L, 0L)
+    }
+  }
+
+  override def reset(agg: AggregationBuffer): Unit = {
+    val buffer = agg.asInstanceOf[MockUDAFBuffer]
+    buffer.nonNullCount = 0L
+    buffer.nullCount = 0L
+  }
+
+  override def init(mode: Mode, parameters: Array[ObjectInspector]): ObjectInspector = {
+    aggMode = mode
+    bufferOI
+  }
+
+  override def iterate(agg: AggregationBuffer, parameters: Array[AnyRef]): Unit = {
+    val buffer = agg.asInstanceOf[MockUDAFBuffer]
+    if (parameters.head eq null) {
+      buffer.nullCount += 1L
+    } else {
+      buffer.nonNullCount += 1L
+    }
+  }
+
+  override def merge(agg: AggregationBuffer, partial: Object): Unit = {
+    if (partial ne null) {
+      val nonNullCount = nonNullCountOI.get(bufferOI.getStructFieldData(partial, nonNullCountField))
+      val nullCount = nullCountOI.get(bufferOI.getStructFieldData(partial, nullCountField))
+      val buffer = agg.asInstanceOf[MockUDAFBuffer2]
+      buffer.nonNullCount += nonNullCount
+      buffer.nullCount += nullCount
+    }
+  }
+
+  // As this method is called for both states, Partial1 and Partial2, the hack in the method
+  // to check for class of aggregation buffer was necessary.
+  override def terminatePartial(agg: AggregationBuffer): AnyRef = {
+    var result: AnyRef = null
+    if (agg.getClass.toString.contains("MockUDAFBuffer2")) {
+      val buffer = agg.asInstanceOf[MockUDAFBuffer2]
+      result = Array[Object](buffer.nonNullCount: java.lang.Long, buffer.nullCount: java.lang.Long)
+    } else {
+      val buffer = agg.asInstanceOf[MockUDAFBuffer]
+      result = Array[Object](buffer.nonNullCount: java.lang.Long, buffer.nullCount: java.lang.Long)
+    }
+    result
+  }
+
+  override def terminate(agg: AggregationBuffer): AnyRef = {
+    val buffer = agg.asInstanceOf[MockUDAFBuffer2]
+    Array[Object](buffer.nonNullCount: java.lang.Long, buffer.nullCount: java.lang.Long)
+  }
+}
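To summarize the test double: `MockUDAFEvaluator2` picks its buffer class from the aggregation mode, which is what makes it a faithful reproduction of SPARK-24935 (an annotation, not part of the commit):

```scala
// Mode-to-buffer mapping in MockUDAFEvaluator2:
//   PARTIAL1, COMPLETE -> MockUDAFBuffer   (consumes original rows via iterate())
//   PARTIAL2, FINAL    -> MockUDAFBuffer2  (consumes partial results via merge())
// A UDAF like this fails if a PARTIAL1-style buffer reaches merge(), which is exactly
// the mix that HiveUDAFBuffer.canDoMerge now detects and repairs.
```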
