Commit ba7db7f

Handle null keys in hash-based comparator, and add tests for collisions
1 parent ef4e397 commit ba7db7f

2 files changed: +137 −4 lines changed

core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala

Lines changed: 6 additions & 3 deletions

@@ -40,7 +40,7 @@ import org.apache.spark.storage.BlockId
  * @param aggregator optional Aggregator with combine functions to use for merging data
  * @param partitioner optional partitioner; if given, sort by partition ID and then key
  * @param ordering optional ordering to sort keys within each partition
- * @param serializer serializer to use
+ * @param serializer serializer to use when spilling to disk
  */
 private[spark] class ExternalSorter[K, V, C](
     aggregator: Option[Aggregator[K, V, C]] = None,
@@ -95,7 +95,11 @@ private[spark] class ExternalSorter[K, V, C](
   // non-equal keys also have this, so we need to do a later pass to find truly equal keys).
   // Note that we ignore this if no aggregator is given.
   private val keyComparator: Comparator[K] = ordering.getOrElse(new Comparator[K] {
-    override def compare(a: K, b: K): Int = a.hashCode() - b.hashCode()
+    override def compare(a: K, b: K): Int = {
+      val h1 = if (a == null) 0 else a.hashCode()
+      val h2 = if (b == null) 0 else b.hashCode()
+      h1 - h2
+    }
   })
 
   private val sortWithinPartitions = ordering.isDefined || aggregator.isDefined
@@ -215,7 +219,6 @@ private[spark] class ExternalSorter[K, V, C](
     val batchSizes = new ArrayBuffer[Long]
 
     // How many elements we have in each partition
-    // TODO: this could become a sparser data structure
     val elementsPerPartition = new Array[Long](numPartitions)
 
     // Flush the disk writer's contents to disk, and update relevant variables
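A note on the comparator change: a null key now hashes to 0, which is also "".hashCode, so a null key can be grouped beside a non-equal key with the same hash; as the comment above says, a later pass finds the truly equal keys. Below is a minimal standalone sketch of the same logic, for illustration only (HashComparatorSketch is a made-up name, not part of the commit):

import java.util.Comparator

object HashComparatorSketch extends App {
  // Same null-safe logic as the keyComparator above, lifted out for
  // illustration. A null key hashes to 0, matching "".hashCode, so the two
  // compare as equal here; the sorter's later equality pass tells truly
  // equal keys apart. Note that h1 - h2 can overflow when the two hash
  // codes have opposite signs; Integer.compare(h1, h2) would avoid that.
  val hashComparator = new Comparator[Any] {
    override def compare(a: Any, b: Any): Int = {
      val h1 = if (a == null) 0 else a.hashCode()
      val h2 = if (b == null) 0 else b.hashCode()
      h1 - h2
    }
  }

  assert(hashComparator.compare(null, "") == 0)   // "" also hashes to 0
  assert(hashComparator.compare("Aa", "BB") == 0) // both hash to 2112
}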

core/src/test/scala/org/apache/spark/util/collection/ExternalSorterSuite.scala

Lines changed: 131 additions & 1 deletion

@@ -17,11 +17,12 @@
 
 package org.apache.spark.util.collection
 
+import scala.collection.mutable.ArrayBuffer
+
 import org.scalatest.FunSuite
 
 import org.apache.spark._
 import org.apache.spark.SparkContext._
-import scala.Some
 
 class ExternalSorterSuite extends FunSuite with LocalSparkContext {
   test("spilling in local cluster") {
@@ -332,4 +333,133 @@ class ExternalSorterSuite extends FunSuite with LocalSparkContext {
     }).toSeq
     assert(results === expected)
   }
+
+  test("spilling with hash collisions") {
+    val conf = new SparkConf(true)
+    conf.set("spark.shuffle.memoryFraction", "0.001")
+    sc = new SparkContext("local-cluster[1,1,512]", "test", conf)
+
+    def createCombiner(i: String) = ArrayBuffer[String](i)
+    def mergeValue(buffer: ArrayBuffer[String], i: String) = buffer += i
+    def mergeCombiners(buffer1: ArrayBuffer[String], buffer2: ArrayBuffer[String]) =
+      buffer1 ++= buffer2
+
+    val agg = new Aggregator[String, String, ArrayBuffer[String]](
+      createCombiner _, mergeValue _, mergeCombiners _)
+
+    val sorter = new ExternalSorter[String, String, ArrayBuffer[String]](
+      Some(agg), None, None, None)
+
+    val collisionPairs = Seq(
+      ("Aa", "BB"),                    // 2112
+      ("to", "v1"),                    // 3707
+      ("variants", "gelato"),          // -1249574770
+      ("Teheran", "Siblings"),         // 231609873
+      ("misused", "horsemints"),       // 1069518484
+      ("isohel", "epistolaries"),      // -1179291542
+      ("righto", "buzzards"),          // -931102253
+      ("hierarch", "crinolines"),      // -1732884796
+      ("inwork", "hypercatalexes"),    // -1183663690
+      ("wainages", "presentencing"),   // 240183619
+      ("trichothecenes", "locular"),   // 339006536
+      ("pomatoes", "eructation")       // 568647356
+    )
+
+    collisionPairs.foreach { case (w1, w2) =>
+      // String.hashCode is documented to use a specific algorithm, but check just in case
+      assert(w1.hashCode === w2.hashCode)
+    }
+
+    val toInsert = (1 to 100000).iterator.map(_.toString).map(s => (s, s)) ++
+      collisionPairs.iterator ++ collisionPairs.iterator.map(_.swap)
+
+    sorter.write(toInsert)
+
+    // A map of collision pairs in both directions
+    val collisionPairsMap = (collisionPairs ++ collisionPairs.map(_.swap)).toMap
+
+    // Avoid map.size or map.iterator.length because this destructively sorts the underlying map
+    var count = 0
+
+    val it = sorter.iterator
+    while (it.hasNext) {
+      val kv = it.next()
+      val expectedValue = ArrayBuffer[String](collisionPairsMap.getOrElse(kv._1, kv._1))
+      assert(kv._2.equals(expectedValue))
+      count += 1
+    }
+    assert(count === 100000 + collisionPairs.size * 2)
+  }
+
+  test("spilling with many hash collisions") {
+    val conf = new SparkConf(true)
+    conf.set("spark.shuffle.memoryFraction", "0.0001")
+    sc = new SparkContext("local-cluster[1,1,512]", "test", conf)
+
+    val agg = new Aggregator[FixedHashObject, Int, Int](_ => 1, _ + _, _ + _)
+    val sorter = new ExternalSorter[FixedHashObject, Int, Int](Some(agg), None, None, None)
+
+    // Insert 10 copies each of lots of objects whose hash codes are either 0 or 1. This causes
+    // problems if the map fails to group together the objects with the same code (SPARK-2043).
+    val toInsert = for (i <- 1 to 10; j <- 1 to 10000) yield (FixedHashObject(j, j % 2), 1)
+    sorter.write(toInsert.iterator)
+
+    val it = sorter.iterator
+    var count = 0
+    while (it.hasNext) {
+      val kv = it.next()
+      assert(kv._2 === 10)
+      count += 1
+    }
+    assert(count === 10000)
+  }
+
+  test("spilling with hash collisions using the Int.MaxValue key") {
+    val conf = new SparkConf(true)
+    conf.set("spark.shuffle.memoryFraction", "0.001")
+    sc = new SparkContext("local-cluster[1,1,512]", "test", conf)
+
+    def createCombiner(i: Int) = ArrayBuffer[Int](i)
+    def mergeValue(buffer: ArrayBuffer[Int], i: Int) = buffer += i
+    def mergeCombiners(buf1: ArrayBuffer[Int], buf2: ArrayBuffer[Int]) = buf1 ++= buf2
+
+    val agg = new Aggregator[Int, Int, ArrayBuffer[Int]](createCombiner, mergeValue, mergeCombiners)
+    val sorter = new ExternalSorter[Int, Int, ArrayBuffer[Int]](Some(agg), None, None, None)
+
+    sorter.write((1 to 100000).iterator.map(i => (i, i)) ++ Iterator((Int.MaxValue, Int.MaxValue)))
+
+    val it = sorter.iterator
+    while (it.hasNext) {
+      // Should not throw NoSuchElementException
+      it.next()
+    }
+  }
+
+  test("spilling with null keys and values") {
+    val conf = new SparkConf(true)
+    conf.set("spark.shuffle.memoryFraction", "0.001")
+    sc = new SparkContext("local-cluster[1,1,512]", "test", conf)
+
+    def createCombiner(i: String) = ArrayBuffer[String](i)
+    def mergeValue(buffer: ArrayBuffer[String], i: String) = buffer += i
+    def mergeCombiners(buf1: ArrayBuffer[String], buf2: ArrayBuffer[String]) = buf1 ++= buf2
+
+    val agg = new Aggregator[String, String, ArrayBuffer[String]](
+      createCombiner, mergeValue, mergeCombiners)
+
+    val sorter = new ExternalSorter[String, String, ArrayBuffer[String]](
+      Some(agg), None, None, None)
+
+    sorter.write((1 to 100000).iterator.map(i => (i.toString, i.toString)) ++ Iterator(
+      (null.asInstanceOf[String], "1"),
+      ("1", null.asInstanceOf[String]),
+      (null.asInstanceOf[String], null.asInstanceOf[String])
+    ))
+
+    val it = sorter.iterator
+    while (it.hasNext) {
+      // Should not throw NullPointerException
+      it.next()
+    }
+  }
 }
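The "spilling with many hash collisions" test uses FixedHashObject, which this diff does not define; it presumably lives elsewhere in the test sources. A hypothetical sketch consistent with its use above, assuming the second constructor argument is the forced hash code (so FixedHashObject(j, j % 2) yields only hash codes 0 and 1):

// Hypothetical reconstruction of the FixedHashObject helper; the real
// definition is not part of this diff. Pinning hashCode to a caller-chosen
// value makes collisions trivial to manufacture, while case-class equality
// still compares both fields, so the 10000 keys stay distinct under equals:
// the degenerate grouping case described in SPARK-2043.
case class FixedHashObject(v: Int, h: Int) extends Serializable {
  override def hashCode(): Int = h
}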
