@@ -22,17 +22,19 @@ import org.scalatest.{BeforeAndAfter, BeforeAndAfterAll, FunSuite}
22
22
import org .apache .spark .sql ._
23
23
import org .apache .spark .sql .test .TestSQLContext ._
24
24
25
- case class IntegerData (i : Int )
26
-
27
25
class PartitionBatchPruningSuite extends FunSuite with BeforeAndAfterAll with BeforeAndAfter {
28
26
val originalColumnBatchSize = columnBatchSize
29
27
val originalInMemoryPartitionPruning = inMemoryPartitionPruning
30
28
31
29
override protected def beforeAll (): Unit = {
32
30
// Make a table with 5 partitions, 2 batches per partition, 10 elements per batch
33
31
setConf(SQLConf .COLUMN_BATCH_SIZE , " 10" )
34
- val rawData = sparkContext.makeRDD(1 to 100 , 5 ).map(IntegerData )
35
- rawData.registerTempTable(" intData" )
32
+
33
+ val rawData = sparkContext.makeRDD((1 to 100 ).map { key =>
34
+ val string = if (((key - 1 ) / 10 ) % 2 == 0 ) null else key.toString
35
+ TestData (key, string)
36
+ }, 5 )
37
+ rawData.registerTempTable(" testData" )
36
38
37
39
// Enable in-memory partition pruning
38
40
setConf(SQLConf .IN_MEMORY_PARTITION_PRUNING , " true" )
@@ -44,48 +46,64 @@ class PartitionBatchPruningSuite extends FunSuite with BeforeAndAfterAll with Be
44
46
}
45
47
46
48
before {
47
- cacheTable(" intData " )
49
+ cacheTable(" testData " )
48
50
}
49
51
50
52
after {
51
- uncacheTable(" intData " )
53
+ uncacheTable(" testData " )
52
54
}
53
55
54
56
// Comparisons
55
- checkBatchPruning(" i = 1" , Seq (1 ), 1 , 1 )
56
- checkBatchPruning(" 1 = i" , Seq (1 ), 1 , 1 )
57
- checkBatchPruning(" i < 12" , 1 to 11 , 1 , 2 )
58
- checkBatchPruning(" i <= 11" , 1 to 11 , 1 , 2 )
59
- checkBatchPruning(" i > 88" , 89 to 100 , 1 , 2 )
60
- checkBatchPruning(" i >= 89" , 89 to 100 , 1 , 2 )
61
- checkBatchPruning(" 12 > i" , 1 to 11 , 1 , 2 )
62
- checkBatchPruning(" 11 >= i" , 1 to 11 , 1 , 2 )
63
- checkBatchPruning(" 88 < i" , 89 to 100 , 1 , 2 )
64
- checkBatchPruning(" 89 <= i" , 89 to 100 , 1 , 2 )
57
+ checkBatchPruning(" SELECT key FROM testData WHERE key = 1" , 1 , 1 )(Seq (1 ))
58
+ checkBatchPruning(" SELECT key FROM testData WHERE 1 = key" , 1 , 1 )(Seq (1 ))
59
+ checkBatchPruning(" SELECT key FROM testData WHERE key < 12" , 1 , 2 )(1 to 11 )
60
+ checkBatchPruning(" SELECT key FROM testData WHERE key <= 11" , 1 , 2 )(1 to 11 )
61
+ checkBatchPruning(" SELECT key FROM testData WHERE key > 88" , 1 , 2 )(89 to 100 )
62
+ checkBatchPruning(" SELECT key FROM testData WHERE key >= 89" , 1 , 2 )(89 to 100 )
63
+ checkBatchPruning(" SELECT key FROM testData WHERE 12 > key" , 1 , 2 )(1 to 11 )
64
+ checkBatchPruning(" SELECT key FROM testData WHERE 11 >= key" , 1 , 2 )(1 to 11 )
65
+ checkBatchPruning(" SELECT key FROM testData WHERE 88 < key" , 1 , 2 )(89 to 100 )
66
+ checkBatchPruning(" SELECT key FROM testData WHERE 89 <= key" , 1 , 2 )(89 to 100 )
67
+
68
+ // IS NULL
69
+ checkBatchPruning(" SELECT key FROM testData WHERE value IS NULL" , 5 , 5 ) {
70
+ (1 to 10 ) ++ (21 to 30 ) ++ (41 to 50 ) ++ (61 to 70 ) ++ (81 to 90 )
71
+ }
72
+
73
+ // IS NOT NULL
74
+ checkBatchPruning(" SELECT key FROM testData WHERE value IS NOT NULL" , 5 , 5 ) {
75
+ (11 to 20 ) ++ (31 to 40 ) ++ (51 to 60 ) ++ (71 to 80 ) ++ (91 to 100 )
76
+ }
65
77
66
78
// Conjunction and disjunction
67
- checkBatchPruning(" i > 8 AND i <= 21" , 9 to 21 , 2 , 3 )
68
- checkBatchPruning(" i < 2 OR i > 99" , Seq (1 , 100 ), 2 , 2 )
69
- checkBatchPruning(" i < 2 OR (i > 78 AND i < 92)" , Seq (1 ) ++ (79 to 91 ), 3 , 4 )
70
- checkBatchPruning(" NOT (i < 88)" , 88 to 100 , 1 , 2 )
79
+ checkBatchPruning(" SELECT key FROM testData WHERE key > 8 AND key <= 21" , 2 , 3 )(9 to 21 )
80
+ checkBatchPruning(" SELECT key FROM testData WHERE key < 2 OR key > 99" , 2 , 2 )(Seq (1 , 100 ))
81
+ checkBatchPruning(" SELECT key FROM testData WHERE key < 2 OR (key > 78 AND key < 92)" , 3 , 4 ) {
82
+ Seq (1 ) ++ (79 to 91 )
83
+ }
71
84
72
85
// With unsupported predicate
73
- checkBatchPruning(" i < 12 AND i IS NOT NULL" , 1 to 11 , 1 , 2 )
74
- checkBatchPruning(s " NOT (i in ( ${(1 to 30 ).mkString(" ," )})) " , 31 to 100 , 5 , 10 )
86
+ checkBatchPruning(" SELECT key FROM testData WHERE NOT (key < 88)" , 1 , 2 )(88 to 100 )
87
+ checkBatchPruning(" SELECT key FROM testData WHERE key < 12 AND key IS NOT NULL" , 1 , 2 )(1 to 11 )
88
+
89
+ {
90
+ val seq = (1 to 30 ).mkString(" , " )
91
+ checkBatchPruning(s " SELECT key FROM testData WHERE NOT (key IN ( $seq)) " , 5 , 10 )(31 to 100 )
92
+ }
75
93
76
94
def checkBatchPruning (
77
- filter : String ,
78
- expectedQueryResult : Seq [Int ],
95
+ query : String ,
79
96
expectedReadPartitions : Int ,
80
- expectedReadBatches : Int ): Unit = {
97
+ expectedReadBatches : Int )(
98
+ expectedQueryResult : => Seq [Int ]): Unit = {
81
99
82
- test(filter ) {
83
- val query = sql(s " SELECT * FROM intData WHERE $filter " )
100
+ test(query ) {
101
+ val schemaRdd = sql(query )
84
102
assertResult(expectedQueryResult.toArray, " Wrong query result" ) {
85
- query .collect().map(_.head).toArray
103
+ schemaRdd .collect().map(_.head).toArray
86
104
}
87
105
88
- val (readPartitions, readBatches) = query .queryExecution.executedPlan.collect {
106
+ val (readPartitions, readBatches) = schemaRdd .queryExecution.executedPlan.collect {
89
107
case in : InMemoryColumnarTableScan => (in.readPartitions.value, in.readBatches.value)
90
108
}.head
91
109
0 commit comments