
Commit df330fa

Author: pgandhi
[SPARK-27207] : Coming up with a unit test for custom UDAF
1 parent db46cf7 commit df330fa

File tree

2 files changed: +102 additions, -30 deletions


sql/core/src/test/scala/org/apache/spark/sql/TypedImperativeAggregateSuite.scala

Lines changed: 102 additions & 0 deletions
@@ -20,6 +20,7 @@ package org.apache.spark.sql
 import java.io.{ByteArrayInputStream, ByteArrayOutputStream, DataInputStream, DataOutputStream}
 
 import org.apache.spark.sql.TypedImperativeAggregateSuite.TypedMax
+import org.apache.spark.sql.TypedImperativeAggregateSuite.TypedMax2
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions.{BoundReference, Expression, GenericInternalRow, ImplicitCastInputTypes, SpecificInternalRow}
 import org.apache.spark.sql.catalyst.expressions.aggregate.TypedImperativeAggregate
@@ -210,6 +211,20 @@ class TypedImperativeAggregateSuite extends QueryTest with SharedSQLContext {
     checkAnswer(query, expected)
   }
 
+  test("SPARK-27207: Ensure aggregate buffers are initialized again for SortBasedAggregate") {
+    withSQLConf("spark.sql.objectHashAggregate.sortBased.fallbackThreshold" -> "5") {
+      val df = data.toDF("value", "key").coalesce(2)
+      val query = df.groupBy($"key").agg(typedMax2($"value"), count($"value"), typedMax2($"value"))
+      val expected = data.groupBy(_._2).toSeq.map { group =>
+        val (key, values) = group
+        val valueMax = values.map(_._1).max
+        val countValue = values.size
+        Row(key, valueMax, countValue, valueMax)
+      }
+      checkAnswer(query, expected)
+    }
+  }
+
   private def typedMax(column: Column): Column = {
     val max = TypedMax(column.expr, nullable = false)
     Column(max.toAggregateExpression())
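
The new test hinges on spark.sql.objectHashAggregate.sortBased.fallbackThreshold: once the in-memory hash map of ObjectHashAggregateExec grows past that many entries in a partition, the operator falls back to sort-based aggregation, which is the path that must initialize the aggregation buffers again. The following standalone sketch shows the same setup outside the test harness; it is not part of this commit, and the app name, data, and use of the built-in percentile_approx function (another TypedImperativeAggregate) are illustrative stand-ins for the suite's private typedMax2 helper.

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.expr

// Illustrative sketch only; names and values here are assumptions, not part of the commit.
object SortBasedFallbackSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .master("local[2]")
      .appName("object-hash-fallback-sketch")
      // With the threshold lowered to 5, ObjectHashAggregateExec abandons hash-based
      // aggregation early and continues with sort-based aggregation, the code path
      // exercised by the SPARK-27207 test above.
      .config("spark.sql.objectHashAggregate.sortBased.fallbackThreshold", "5")
      .getOrCreate()
    import spark.implicits._

    val df = (1 to 100).map(i => (i, i % 20)).toDF("value", "key").coalesce(2)
    // percentile_approx is a built-in TypedImperativeAggregate, so the query is planned
    // with ObjectHashAggregateExec, just like the suite's TypedMax2 aggregate.
    df.groupBy($"key").agg(expr("percentile_approx(value, 0.5)")).show()

    spark.stop()
  }
}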
@@ -219,6 +234,11 @@ class TypedImperativeAggregateSuite extends QueryTest with SharedSQLContext {
     val max = TypedMax(column.expr, nullable = true)
     Column(max.toAggregateExpression())
   }
+
+  private def typedMax2(column: Column): Column = {
+    val max = TypedMax2(column.expr, nullable = false)
+    Column(max.toAggregateExpression())
+  }
 }
 
 object TypedImperativeAggregateSuite {
@@ -299,5 +319,87 @@ object TypedImperativeAggregateSuite {
     }
   }
 
+  /**
+   * Calculate the max value with object aggregation buffer. This stores class MaxValue
+   * in aggregation buffer.
+   */
+  private case class TypedMax2(
+      child: Expression,
+      nullable: Boolean = false,
+      mutableAggBufferOffset: Int = 0,
+      inputAggBufferOffset: Int = 0)
+    extends TypedImperativeAggregate[MaxValue] with ImplicitCastInputTypes {
+
+    var maxValueBuffer: MaxValue = null
+    override def createAggregationBuffer(): MaxValue = {
+      // Returns Int.MinValue if all inputs are null
+      maxValueBuffer = new MaxValue(Int.MinValue)
+      maxValueBuffer
+    }
+
+    override def update(buffer: MaxValue, input: InternalRow): MaxValue = {
+      child.eval(input) match {
+        case inputValue: Int =>
+          if (inputValue > buffer.value) {
+            buffer.value = inputValue
+            buffer.isValueSet = true
+          }
+        case null => // skip
+      }
+      buffer
+    }
+
+    override def merge(bufferMax: MaxValue, inputMax: MaxValue): MaxValue = {
+      // The below if condition will throw a Null Pointer Exception if initialize() is not called
+      if (maxValueBuffer.isValueSet) {
+        // do nothing
+      }
+      if (inputMax.value > bufferMax.value) {
+        bufferMax.value = inputMax.value
+        bufferMax.isValueSet = bufferMax.isValueSet || inputMax.isValueSet
+      }
+      bufferMax
+    }
+
+    override def eval(bufferMax: MaxValue): Any = {
+      if (nullable && bufferMax.isValueSet == false) {
+        null
+      } else {
+        bufferMax.value
+      }
+    }
+
+    override lazy val deterministic: Boolean = true
+
+    override def children: Seq[Expression] = Seq(child)
+
+    override def inputTypes: Seq[AbstractDataType] = Seq(IntegerType)
+
+    override def dataType: DataType = IntegerType
+
+    override def withNewMutableAggBufferOffset(newOffset: Int): TypedImperativeAggregate[MaxValue] =
+      copy(mutableAggBufferOffset = newOffset)
+
+    override def withNewInputAggBufferOffset(newOffset: Int): TypedImperativeAggregate[MaxValue] =
+      copy(inputAggBufferOffset = newOffset)
+
+    override def serialize(buffer: MaxValue): Array[Byte] = {
+      val out = new ByteArrayOutputStream()
+      val stream = new DataOutputStream(out)
+      stream.writeBoolean(buffer.isValueSet)
+      stream.writeInt(buffer.value)
+      out.toByteArray
+    }
+
+    override def deserialize(storageFormat: Array[Byte]): MaxValue = {
+      val in = new ByteArrayInputStream(storageFormat)
+      val stream = new DataInputStream(in)
+      val isValueSet = stream.readBoolean()
+      val value = stream.readInt()
+      new MaxValue(value, isValueSet)
+    }
+  }
+
   private class MaxValue(var value: Int, var isValueSet: Boolean = false)
 }
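
Note how merge() in TypedMax2 reads maxValueBuffer, a field that is only assigned inside createAggregationBuffer(): if the sort-based fallback path calls merge() on an instance whose buffer was never re-initialized, that read throws a NullPointerException, which is what the new test is designed to catch. Below is a simplified, self-contained stand-in of that dependency; TypedMax2Like and BufferInitSketch are invented names for illustration, since the suite's TypedMax2 and MaxValue are private to TypedImperativeAggregateSuite.

// Illustrative stand-in classes, not Spark or suite classes.
object BufferInitSketch extends App {
  class MaxValue(var value: Int, var isValueSet: Boolean = false)

  class TypedMax2Like {
    // Assigned only when createAggregationBuffer() runs, mirroring TypedMax2 above.
    var maxValueBuffer: MaxValue = null

    def createAggregationBuffer(): MaxValue = {
      maxValueBuffer = new MaxValue(Int.MinValue)
      maxValueBuffer
    }

    def merge(bufferMax: MaxValue, inputMax: MaxValue): MaxValue = {
      // Throws NullPointerException if createAggregationBuffer() was never called,
      // the situation SPARK-27207 describes for the sort-based fallback path.
      if (maxValueBuffer.isValueSet) { /* do nothing */ }
      if (inputMax.value > bufferMax.value) {
        bufferMax.value = inputMax.value
        bufferMax.isValueSet = bufferMax.isValueSet || inputMax.isValueSet
      }
      bufferMax
    }
  }

  val agg = new TypedMax2Like
  // agg.merge(new MaxValue(1), new MaxValue(2))  // would throw NullPointerException here
  val buffer = agg.createAggregationBuffer()       // buffer (re)initialization step
  println(agg.merge(buffer, new MaxValue(7, isValueSet = true)).value)  // prints 7
}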

sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDAFSuite.scala

Lines changed: 0 additions & 30 deletions
@@ -49,16 +49,6 @@ class HiveUDAFSuite extends QueryTest with TestHiveSingleton with SQLTestUtils {
       (2: Integer) -> null,
       (3: Integer) -> null
     ).toDF("key", "value").repartition(2).createOrReplaceTempView("t")
-    Seq(
-      (0: Integer) -> "val_0",
-      (1: Integer) -> "val_1",
-      (2: Integer) -> "val_2",
-      (3: Integer) -> "val_3",
-      (4: Integer) -> "val_4",
-      (5: Integer) -> "val_5",
-      (6: Integer) -> null,
-      (7: Integer) -> null
-    ).toDF("key", "value").repartition(2).createOrReplaceTempView("t2")
   }
 
   protected override def afterAll(): Unit = {
@@ -133,26 +123,6 @@ class HiveUDAFSuite extends QueryTest with TestHiveSingleton with SQLTestUtils {
     }
   }
 
-  test("SPARK-27207: customized Hive UDAF with two aggregation buffers for Sort" +
-    " Based Aggregation") {
-    withSQLConf("spark.sql.objectHashAggregate.sortBased.fallbackThreshold" -> "2") {
-      val df = sql("SELECT key % 2, mock2(value) FROM t2 GROUP BY key % 2")
-
-      val aggs = df.queryExecution.executedPlan.collect {
-        case agg: ObjectHashAggregateExec => agg
-      }
-
-      // There should be two aggregate operators, one for partial aggregation, and the other for
-      // global aggregation.
-      assert(aggs.length == 2)
-
-      checkAnswer(df, Seq(
-        Row(0, Row(3, 1)),
-        Row(1, Row(3, 1))
-      ))
-    }
-  }
-
   test("call JAVA UDAF") {
     withTempView("temp") {
       withUserDefinedFunction("myDoubleAvg" -> false) {
