@@ -20,6 +20,7 @@ package org.apache.spark.sql
 import java.io.{ByteArrayInputStream, ByteArrayOutputStream, DataInputStream, DataOutputStream}
 
 import org.apache.spark.sql.TypedImperativeAggregateSuite.TypedMax
+import org.apache.spark.sql.TypedImperativeAggregateSuite.TypedMax2
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions.{BoundReference, Expression, GenericInternalRow, ImplicitCastInputTypes, SpecificInternalRow}
 import org.apache.spark.sql.catalyst.expressions.aggregate.TypedImperativeAggregate
@@ -210,6 +211,20 @@ class TypedImperativeAggregateSuite extends QueryTest with SharedSQLContext {
     checkAnswer(query, expected)
   }
 
+  test("SPARK-27207: Ensure aggregate buffers are initialized again for SortBasedAggregate") {
+    withSQLConf("spark.sql.objectHashAggregate.sortBased.fallbackThreshold" -> "5") {
+      val df = data.toDF("value", "key").coalesce(2)
+      val query = df.groupBy($"key").agg(typedMax2($"value"), count($"value"), typedMax2($"value"))
+      val expected = data.groupBy(_._2).toSeq.map { group =>
+        val (key, values) = group
+        val valueMax = values.map(_._1).max
+        val countValue = values.size
+        Row(key, valueMax, countValue, valueMax)
+      }
+      checkAnswer(query, expected)
+    }
+  }
+
   private def typedMax(column: Column): Column = {
     val max = TypedMax(column.expr, nullable = false)
     Column(max.toAggregateExpression())
@@ -219,6 +234,11 @@ class TypedImperativeAggregateSuite extends QueryTest with SharedSQLContext {
     val max = TypedMax(column.expr, nullable = true)
     Column(max.toAggregateExpression())
   }
+
+  private def typedMax2(column: Column): Column = {
+    val max = TypedMax2(column.expr, nullable = false)
+    Column(max.toAggregateExpression())
+  }
 }
 
 object TypedImperativeAggregateSuite {
@@ -299,5 +319,87 @@ object TypedImperativeAggregateSuite {
     }
   }
 
+  /**
+   * Calculates the max value using an object aggregation buffer. The object stored in the
+   * aggregation buffer is an instance of class MaxValue.
+   */
+  private case class TypedMax2(
+      child: Expression,
+      nullable: Boolean = false,
+      mutableAggBufferOffset: Int = 0,
+      inputAggBufferOffset: Int = 0)
+    extends TypedImperativeAggregate[MaxValue] with ImplicitCastInputTypes {
+
+    // Set in createAggregationBuffer(); merge() uses it to detect missing re-initialization.
+    var maxValueBuffer: MaxValue = null
+    override def createAggregationBuffer(): MaxValue = {
+      // Returns Int.MinValue if all inputs are null
+      maxValueBuffer = new MaxValue(Int.MinValue)
+      maxValueBuffer
+    }
+
+    override def update(buffer: MaxValue, input: InternalRow): MaxValue = {
+      child.eval(input) match {
+        case inputValue: Int =>
+          if (inputValue > buffer.value) {
+            buffer.value = inputValue
+            buffer.isValueSet = true
+          }
+        case null => // skip
+      }
+      buffer
+    }
+
+    override def merge(bufferMax: MaxValue, inputMax: MaxValue): MaxValue = {
+      // Dereferencing maxValueBuffer throws a NullPointerException if initialize() was not called.
+      if (maxValueBuffer.isValueSet) {
+        // do nothing
+      }
+      if (inputMax.value > bufferMax.value) {
+        bufferMax.value = inputMax.value
+        bufferMax.isValueSet = bufferMax.isValueSet || inputMax.isValueSet
+      }
+      bufferMax
+    }
+
+    override def eval(bufferMax: MaxValue): Any = {
+      if (nullable && bufferMax.isValueSet == false) {
+        null
+      } else {
+        bufferMax.value
+      }
+    }
+
+    override lazy val deterministic: Boolean = true
+
+    override def children: Seq[Expression] = Seq(child)
+
+    override def inputTypes: Seq[AbstractDataType] = Seq(IntegerType)
+
+    override def dataType: DataType = IntegerType
+
+    override def withNewMutableAggBufferOffset(newOffset: Int): TypedImperativeAggregate[MaxValue] =
+      copy(mutableAggBufferOffset = newOffset)
+
+    override def withNewInputAggBufferOffset(newOffset: Int): TypedImperativeAggregate[MaxValue] =
+      copy(inputAggBufferOffset = newOffset)
+
+    override def serialize(buffer: MaxValue): Array[Byte] = {
+      val out = new ByteArrayOutputStream()
+      val stream = new DataOutputStream(out)
+      stream.writeBoolean(buffer.isValueSet)
+      stream.writeInt(buffer.value)
+      out.toByteArray
+    }
+
+    override def deserialize(storageFormat: Array[Byte]): MaxValue = {
+      val in = new ByteArrayInputStream(storageFormat)
+      val stream = new DataInputStream(in)
+      val isValueSet = stream.readBoolean()
+      val value = stream.readInt()
+      new MaxValue(value, isValueSet)
+    }
+  }
+
   private class MaxValue(var value: Int, var isValueSet: Boolean = false)
 }