@@ -303,6 +303,13 @@ private[hive] case class HiveGenericUDTF(
303
303
* - `wrap()`/`wrapperFor()`: from 3 to 1
304
304
* - `unwrap()`/`unwrapperFor()`: from 1 to 3
305
305
* - `GenericUDAFEvaluator.terminatePartial()`: from 2 to 3
306
+ *
307
+ * Note that, Hive UDAF is initialized with aggregate mode, and some specific Hive UDAFs can't
308
+ * mix UPDATE and MERGE actions during its life cycle. However, Spark may do UPDATE on a UDAF and
309
+ * then do MERGE, in case of hash aggregate falling back to sort aggregate. To work around this
310
+ * issue, we track the ability to do MERGE in the Hive UDAF aggregate buffer. If Spark does
311
+ * UPDATE then MERGE, we can detect it and re-create the aggregate buffer with a different
312
+ * aggregate mode.
306
313
*/
307
314
private [hive] case class HiveUDAFFunction (
308
315
name : String ,
@@ -311,7 +318,7 @@ private[hive] case class HiveUDAFFunction(
311
318
isUDAFBridgeRequired : Boolean = false ,
312
319
mutableAggBufferOffset : Int = 0 ,
313
320
inputAggBufferOffset : Int = 0 )
314
- extends TypedImperativeAggregate [GenericUDAFEvaluator . AggregationBuffer ]
321
+ extends TypedImperativeAggregate [HiveUDAFBuffer ]
315
322
with HiveInspectors
316
323
with UserDefinedExpression {
317
324
@@ -352,29 +359,21 @@ private[hive] case class HiveUDAFFunction(
352
359
HiveEvaluator (evaluator, evaluator.init(GenericUDAFEvaluator .Mode .PARTIAL1 , inputInspectors))
353
360
}
354
361
355
- // The UDAF evaluator used to merge partial aggregation results.
362
+ // The UDAF evaluator used to consume partial aggregation results and produce final results.
363
+ // Hive `ObjectInspector` used to inspect final results.
356
364
@ transient
357
- private lazy val partial2ModeEvaluator = {
365
+ private lazy val finalHiveEvaluator = {
358
366
val evaluator = newEvaluator()
359
- evaluator.init(GenericUDAFEvaluator .Mode .PARTIAL2 , Array (partial1HiveEvaluator.objectInspector))
360
- evaluator
367
+ HiveEvaluator (
368
+ evaluator,
369
+ evaluator.init(GenericUDAFEvaluator .Mode .FINAL , Array (partial1HiveEvaluator.objectInspector)))
361
370
}
362
371
363
372
// Spark SQL data type of partial aggregation results
364
373
@ transient
365
374
private lazy val partialResultDataType =
366
375
inspectorToDataType(partial1HiveEvaluator.objectInspector)
367
376
368
- // The UDAF evaluator used to compute the final result from a partial aggregation result objects.
369
- // Hive `ObjectInspector` used to inspect the final aggregation result object.
370
- @ transient
371
- private lazy val finalHiveEvaluator = {
372
- val evaluator = newEvaluator()
373
- HiveEvaluator (
374
- evaluator,
375
- evaluator.init(GenericUDAFEvaluator .Mode .FINAL , Array (partial1HiveEvaluator.objectInspector)))
376
- }
377
-
378
377
// Wrapper functions used to wrap Spark SQL input arguments into Hive specific format.
379
378
@ transient
380
379
private lazy val inputWrappers = children.map(x => wrapperFor(toInspector(x), x.dataType)).toArray
@@ -401,41 +400,74 @@ private[hive] case class HiveUDAFFunction(
401
400
s " $name( $distinct${children.map(_.sql).mkString(" , " )}) "
402
401
}
403
402
404
- override def createAggregationBuffer (): AggregationBuffer =
405
- partial1HiveEvaluator.evaluator.getNewAggregationBuffer
403
+ // The hive UDAF may create different buffers to handle different inputs: original data or
404
+ // aggregate buffer. However, the Spark UDAF framework does not expose this information when
405
+ // creating the buffer. Here we return null, and create the buffer in `update` and `merge`
406
+ // on demand, so that we can know what input we are dealing with.
407
+ override def createAggregationBuffer (): HiveUDAFBuffer = null
406
408
407
409
@ transient
408
410
private lazy val inputProjection = UnsafeProjection .create(children)
409
411
410
- override def update (buffer : AggregationBuffer , input : InternalRow ): AggregationBuffer = {
412
+ override def update (buffer : HiveUDAFBuffer , input : InternalRow ): HiveUDAFBuffer = {
413
+ // The input is original data, we create buffer with the partial1 evaluator.
414
+ val nonNullBuffer = if (buffer == null ) {
415
+ HiveUDAFBuffer (partial1HiveEvaluator.evaluator.getNewAggregationBuffer, false )
416
+ } else {
417
+ buffer
418
+ }
419
+
420
+ assert(! nonNullBuffer.canDoMerge, " can not call `merge` then `update` on a Hive UDAF." )
421
+
411
422
partial1HiveEvaluator.evaluator.iterate(
412
- buffer , wrap(inputProjection(input), inputWrappers, cached, inputDataTypes))
413
- buffer
423
+ nonNullBuffer.buf , wrap(inputProjection(input), inputWrappers, cached, inputDataTypes))
424
+ nonNullBuffer
414
425
}
415
426
416
- override def merge (buffer : AggregationBuffer , input : AggregationBuffer ): AggregationBuffer = {
427
+ override def merge (buffer : HiveUDAFBuffer , input : HiveUDAFBuffer ): HiveUDAFBuffer = {
428
+ // The input is aggregate buffer, we create buffer with the final evaluator.
429
+ val nonNullBuffer = if (buffer == null ) {
430
+ HiveUDAFBuffer (finalHiveEvaluator.evaluator.getNewAggregationBuffer, true )
431
+ } else {
432
+ buffer
433
+ }
434
+
435
+ // It's possible that we've called `update` of this Hive UDAF, and some specific Hive UDAF
436
+ // implementation can't mix the `update` and `merge` calls during its life cycle. To work
437
+ // around it, here we create a fresh buffer with final evaluator, and merge the existing buffer
438
+ // to it, and replace the existing buffer with it.
439
+ val mergeableBuf = if (! nonNullBuffer.canDoMerge) {
440
+ val newBuf = finalHiveEvaluator.evaluator.getNewAggregationBuffer
441
+ finalHiveEvaluator.evaluator.merge(
442
+ newBuf, partial1HiveEvaluator.evaluator.terminatePartial(nonNullBuffer.buf))
443
+ HiveUDAFBuffer (newBuf, true )
444
+ } else {
445
+ nonNullBuffer
446
+ }
447
+
417
448
// The 2nd argument of the Hive `GenericUDAFEvaluator.merge()` method is an input aggregation
418
449
// buffer in the 3rd format mentioned in the ScalaDoc of this class. Originally, Hive converts
419
450
// this `AggregationBuffer`s into this format before shuffling partial aggregation results, and
420
451
// calls `GenericUDAFEvaluator.terminatePartial()` to do the conversion.
421
- partial2ModeEvaluator.merge(buffer, partial1HiveEvaluator.evaluator.terminatePartial(input))
422
- buffer
452
+ finalHiveEvaluator.evaluator.merge(
453
+ mergeableBuf.buf, partial1HiveEvaluator.evaluator.terminatePartial(input.buf))
454
+ mergeableBuf
423
455
}
424
456
425
- override def eval (buffer : AggregationBuffer ): Any = {
426
- resultUnwrapper(finalHiveEvaluator.evaluator.terminate(buffer))
457
+ override def eval (buffer : HiveUDAFBuffer ): Any = {
458
+ resultUnwrapper(finalHiveEvaluator.evaluator.terminate(buffer.buf ))
427
459
}
428
460
429
- override def serialize (buffer : AggregationBuffer ): Array [Byte ] = {
461
+ override def serialize (buffer : HiveUDAFBuffer ): Array [Byte ] = {
430
462
// Serializes an `AggregationBuffer` that holds partial aggregation results so that we can
431
463
// shuffle it for global aggregation later.
432
- aggBufferSerDe.serialize(buffer)
464
+ aggBufferSerDe.serialize(buffer.buf )
433
465
}
434
466
435
- override def deserialize (bytes : Array [Byte ]): AggregationBuffer = {
467
+ override def deserialize (bytes : Array [Byte ]): HiveUDAFBuffer = {
436
468
// Deserializes an `AggregationBuffer` from the shuffled partial aggregation phase to prepare
437
469
// for global aggregation by merging multiple partial aggregation results within a single group.
438
- aggBufferSerDe.deserialize(bytes)
470
+ HiveUDAFBuffer ( aggBufferSerDe.deserialize(bytes), false )
439
471
}
440
472
441
473
// Helper class used to de/serialize Hive UDAF `AggregationBuffer` objects
@@ -450,11 +482,19 @@ private[hive] case class HiveUDAFFunction(
450
482
private val mutableRow = new GenericInternalRow (1 )
451
483
452
484
def serialize (buffer : AggregationBuffer ): Array [Byte ] = {
485
+ // The buffer may be null if there is no input. It's unclear if the hive UDAF accepts null
486
+ // buffer, for safety we create an empty buffer here.
487
+ val nonNullBuffer = if (buffer == null ) {
488
+ partial1HiveEvaluator.evaluator.getNewAggregationBuffer
489
+ } else {
490
+ buffer
491
+ }
492
+
453
493
// `GenericUDAFEvaluator.terminatePartial()` converts an `AggregationBuffer` into an object
454
494
// that can be inspected by the `ObjectInspector` returned by `GenericUDAFEvaluator.init()`.
455
495
// Then we can unwrap it to a Spark SQL value.
456
496
mutableRow.update(0 , partialResultUnwrapper(
457
- partial1HiveEvaluator.evaluator.terminatePartial(buffer )))
497
+ partial1HiveEvaluator.evaluator.terminatePartial(nonNullBuffer )))
458
498
val unsafeRow = projection(mutableRow)
459
499
val bytes = ByteBuffer .allocate(unsafeRow.getSizeInBytes)
460
500
unsafeRow.writeTo(bytes)
@@ -466,12 +506,14 @@ private[hive] case class HiveUDAFFunction(
466
506
// returned by `GenericUDAFEvaluator.terminatePartial()` back to an `AggregationBuffer`. The
467
507
// workaround here is creating an initial `AggregationBuffer` first and then merge the
468
508
// deserialized object into the buffer.
469
- val buffer = partial2ModeEvaluator .getNewAggregationBuffer
509
+ val buffer = finalHiveEvaluator.evaluator .getNewAggregationBuffer
470
510
val unsafeRow = new UnsafeRow (1 )
471
511
unsafeRow.pointTo(bytes, bytes.length)
472
512
val partialResult = unsafeRow.get(0 , partialResultDataType)
473
- partial2ModeEvaluator .merge(buffer, partialResultWrapper(partialResult))
513
+ finalHiveEvaluator.evaluator .merge(buffer, partialResultWrapper(partialResult))
474
514
buffer
475
515
}
476
516
}
477
517
}
518
+
519
+ case class HiveUDAFBuffer (buf : AggregationBuffer , canDoMerge : Boolean )
0 commit comments