@@ -19,12 +19,12 @@ package org.apache.spark.util.collection
1919
2020import scala .collection .mutable .ArrayBuffer
2121
22- import org .scalatest .FunSuite
22+ import org .scalatest .{ PrivateMethodTester , FunSuite }
2323
2424import org .apache .spark ._
2525import org .apache .spark .SparkContext ._
2626
27- class ExternalSorterSuite extends FunSuite with LocalSparkContext {
27+ class ExternalSorterSuite extends FunSuite with LocalSparkContext with PrivateMethodTester {
2828 private def createSparkConf (loadDefaults : Boolean ): SparkConf = {
2929 val conf = new SparkConf (loadDefaults)
3030 // Make the Java serializer write a reset instruction (TC_RESET) after each object to test
@@ -36,6 +36,16 @@ class ExternalSorterSuite extends FunSuite with LocalSparkContext {
3636 conf
3737 }
3838
39+ private def assertBypassedMergeSort (sorter : ExternalSorter [_, _, _]): Unit = {
40+ val bypassMergeSort = PrivateMethod [Boolean ](' bypassMergeSort )
41+ assert(sorter.invokePrivate(bypassMergeSort()), " sorter did not bypass merge-sort" )
42+ }
43+
44+ private def assertDidNotBypassMergeSort (sorter : ExternalSorter [_, _, _]): Unit = {
45+ val bypassMergeSort = PrivateMethod [Boolean ](' bypassMergeSort )
46+ assert(! sorter.invokePrivate(bypassMergeSort()), " sorter bypassed merge-sort" )
47+ }
48+
3949 test(" empty data stream" ) {
4050 val conf = new SparkConf (false )
4151 conf.set(" spark.shuffle.memoryFraction" , " 0.001" )
@@ -123,7 +133,7 @@ class ExternalSorterSuite extends FunSuite with LocalSparkContext {
123133
124134 val sorter = new ExternalSorter [Int , Int , Int ](
125135 None , Some (new HashPartitioner (7 )), Some (ord), None )
126- assert( ! sorter.bypassMergeSort, " sorter bypassed merge-sort " )
136+ assertDidNotBypassMergeSort( sorter)
127137 sorter.insertAll(elements)
128138 assert(sc.env.blockManager.diskBlockManager.getAllFiles().length > 0 ) // Make sure it spilled
129139 val iter = sorter.partitionedIterator.map(p => (p._1, p._2.toList))
@@ -147,7 +157,7 @@ class ExternalSorterSuite extends FunSuite with LocalSparkContext {
147157
148158 val sorter = new ExternalSorter [Int , Int , Int ](
149159 None , Some (new HashPartitioner (7 )), None , None )
150- assert (sorter.bypassMergeSort, " sorter did not bypass merge-sort " )
160+ assertBypassedMergeSort (sorter)
151161 sorter.insertAll(elements)
152162 assert(sc.env.blockManager.diskBlockManager.getAllFiles().length > 0 ) // Make sure it spilled
153163 val iter = sorter.partitionedIterator.map(p => (p._1, p._2.toList))
@@ -314,15 +324,15 @@ class ExternalSorterSuite extends FunSuite with LocalSparkContext {
314324
315325 val sorter = new ExternalSorter [Int , Int , Int ](
316326 None , Some (new HashPartitioner (3 )), Some (ord), None )
317- assert( ! sorter.bypassMergeSort, " sorter bypassed merge-sort " )
327+ assertDidNotBypassMergeSort( sorter)
318328 sorter.insertAll((0 until 100000 ).iterator.map(i => (i, i)))
319329 assert(diskBlockManager.getAllFiles().length > 0 )
320330 sorter.stop()
321331 assert(diskBlockManager.getAllBlocks().length === 0 )
322332
323333 val sorter2 = new ExternalSorter [Int , Int , Int ](
324334 None , Some (new HashPartitioner (3 )), Some (ord), None )
325- assert( ! sorter2.bypassMergeSort, " sorter bypassed merge-sort " )
335+ assertDidNotBypassMergeSort( sorter2)
326336 sorter2.insertAll((0 until 100000 ).iterator.map(i => (i, i)))
327337 assert(diskBlockManager.getAllFiles().length > 0 )
328338 assert(sorter2.iterator.toSet === (0 until 100000 ).map(i => (i, i)).toSet)
@@ -338,14 +348,14 @@ class ExternalSorterSuite extends FunSuite with LocalSparkContext {
338348 val diskBlockManager = SparkEnv .get.blockManager.diskBlockManager
339349
340350 val sorter = new ExternalSorter [Int , Int , Int ](None , Some (new HashPartitioner (3 )), None , None )
341- assert (sorter.bypassMergeSort, " sorter did not bypass merge-sort " )
351+ assertBypassedMergeSort (sorter)
342352 sorter.insertAll((0 until 100000 ).iterator.map(i => (i, i)))
343353 assert(diskBlockManager.getAllFiles().length > 0 )
344354 sorter.stop()
345355 assert(diskBlockManager.getAllBlocks().length === 0 )
346356
347357 val sorter2 = new ExternalSorter [Int , Int , Int ](None , Some (new HashPartitioner (3 )), None , None )
348- assert (sorter2.bypassMergeSort, " sorter did not bypass merge-sort " )
358+ assertBypassedMergeSort (sorter2)
349359 sorter2.insertAll((0 until 100000 ).iterator.map(i => (i, i)))
350360 assert(diskBlockManager.getAllFiles().length > 0 )
351361 assert(sorter2.iterator.toSet === (0 until 100000 ).map(i => (i, i)).toSet)
@@ -364,7 +374,7 @@ class ExternalSorterSuite extends FunSuite with LocalSparkContext {
364374
365375 val sorter = new ExternalSorter [Int , Int , Int ](
366376 None , Some (new HashPartitioner (3 )), Some (ord), None )
367- assert( ! sorter.bypassMergeSort, " sorter bypassed merge-sort " )
377+ assertDidNotBypassMergeSort( sorter)
368378 intercept[SparkException ] {
369379 sorter.insertAll((0 until 100000 ).iterator.map(i => {
370380 if (i == 99990 ) {
@@ -386,7 +396,7 @@ class ExternalSorterSuite extends FunSuite with LocalSparkContext {
386396 val diskBlockManager = SparkEnv .get.blockManager.diskBlockManager
387397
388398 val sorter = new ExternalSorter [Int , Int , Int ](None , Some (new HashPartitioner (3 )), None , None )
389- assert (sorter.bypassMergeSort, " sorter did not bypass merge-sort " )
399+ assertBypassedMergeSort (sorter)
390400 intercept[SparkException ] {
391401 sorter.insertAll((0 until 100000 ).iterator.map(i => {
392402 if (i == 99990 ) {
@@ -681,20 +691,20 @@ class ExternalSorterSuite extends FunSuite with LocalSparkContext {
681691
682692 val sorter1 = new ExternalSorter [Int , Int , Int ](
683693 None , Some (new HashPartitioner (FEW_PARTITIONS )), None , None )
684- assert (sorter1.bypassMergeSort, " sorter did not bypass merge-sort " )
694+ assertBypassedMergeSort (sorter1)
685695
686696 val sorter2 = new ExternalSorter [Int , Int , Int ](
687697 None , Some (new HashPartitioner (MANY_PARTITIONS )), None , None )
688- assert( ! sorter2.bypassMergeSort, " sorter bypassed merge-sort " )
698+ assertDidNotBypassMergeSort( sorter2)
689699
690700 // Sorters with an ordering or aggregator: should not bypass even if they have few partitions
691701
692702 val sorter3 = new ExternalSorter [Int , Int , Int ](
693703 None , Some (new HashPartitioner (FEW_PARTITIONS )), Some (ord), None )
694- assert( ! sorter3.bypassMergeSort, " sorter bypassed merge-sort " )
704+ assertDidNotBypassMergeSort( sorter3)
695705
696706 val sorter4 = new ExternalSorter [Int , Int , Int ](
697707 Some (agg), Some (new HashPartitioner (FEW_PARTITIONS )), None , None )
698- assert( ! sorter4.bypassMergeSort, " sorter bypassed merge-sort " )
708+ assertDidNotBypassMergeSort( sorter4)
699709 }
700710}
0 commit comments