Skip to content

Commit a8c818d

Browse files
committed
Refines tests
1 parent 1d01074 commit a8c818d

File tree

3 files changed

+70
-31
lines changed

3 files changed

+70
-31
lines changed

sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -66,14 +66,29 @@ private[sql] case class InMemoryRelation(
6666
batchStats.value.map(row => sizeOfRow.eval(row).asInstanceOf[Long]).sum
6767
}
6868

69+
// Statistics propagation contracts:
70+
// 1. Non-null `_statistics` must reflect the actual statistics of the underlying data
71+
// 2. Only propagate statistics when `_statistics` is non-null
72+
private def statisticsToBePropagated = if (_statistics == null) {
73+
val updatedStats = statistics
74+
if (_statistics == null) null else updatedStats
75+
} else {
76+
_statistics
77+
}
78+
6979
override def statistics = if (_statistics == null) {
7080
if (batchStats.value.isEmpty) {
81+
// The underlying columnar RDD hasn't been materialized yet, so no useful statistics are
82+
// available; return the default statistics.
7183
Statistics(sizeInBytes = child.sqlContext.defaultSizeInBytes)
7284
} else {
85+
// The underlying columnar RDD has been materialized and the required information has been
86+
// collected via the `batchStats` accumulator; compute the final statistics and update `_statistics`.
7387
_statistics = Statistics(sizeInBytes = computeSizeInBytes)
7488
_statistics
7589
}
7690
} else {
91+
// Pre-computed statistics
7792
_statistics
7893
}
7994

@@ -129,7 +144,7 @@ private[sql] case class InMemoryRelation(
129144
def withOutput(newOutput: Seq[Attribute]): InMemoryRelation = {
130145
InMemoryRelation(
131146
newOutput, useCompression, batchSize, storageLevel, child)(
132-
_cachedColumnBuffers, if (_statistics == null) statistics else _statistics)
147+
_cachedColumnBuffers, statisticsToBePropagated)
133148
}
134149

135150
override def children = Seq.empty
@@ -142,7 +157,7 @@ private[sql] case class InMemoryRelation(
142157
storageLevel,
143158
child)(
144159
_cachedColumnBuffers,
145-
if (_statistics == null) statistics else _statistics).asInstanceOf[this.type]
160+
statisticsToBePropagated).asInstanceOf[this.type]
146161
}
147162

148163
def cachedColumnBuffers = _cachedColumnBuffers

sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnStatsSuite.scala

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,12 @@ class ColumnStatsSuite extends FunSuite {
6161
assertResult(values.min(ordering), "Wrong lower bound")(stats(0))
6262
assertResult(values.max(ordering), "Wrong upper bound")(stats(1))
6363
assertResult(10, "Wrong null count")(stats(2))
64+
assertResult(20, "Wrong row count")(stats(3))
65+
assertResult(stats(4), "Wrong size in bytes") {
66+
rows.map { row =>
67+
if (row.isNullAt(0)) 4 else columnType.actualSize(row, 0)
68+
}.sum
69+
}
6470
}
6571
}
6672
}

sql/core/src/test/scala/org/apache/spark/sql/columnar/PartitionBatchPruningSuite.scala

Lines changed: 47 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -22,17 +22,19 @@ import org.scalatest.{BeforeAndAfter, BeforeAndAfterAll, FunSuite}
2222
import org.apache.spark.sql._
2323
import org.apache.spark.sql.test.TestSQLContext._
2424

25-
case class IntegerData(i: Int)
26-
2725
class PartitionBatchPruningSuite extends FunSuite with BeforeAndAfterAll with BeforeAndAfter {
2826
val originalColumnBatchSize = columnBatchSize
2927
val originalInMemoryPartitionPruning = inMemoryPartitionPruning
3028

3129
override protected def beforeAll(): Unit = {
3230
// Make a table with 5 partitions, 2 batches per partition, 10 elements per batch
3331
setConf(SQLConf.COLUMN_BATCH_SIZE, "10")
34-
val rawData = sparkContext.makeRDD(1 to 100, 5).map(IntegerData)
35-
rawData.registerTempTable("intData")
32+
33+
val rawData = sparkContext.makeRDD((1 to 100).map { key =>
34+
val string = if (((key - 1) / 10) % 2 == 0) null else key.toString
35+
TestData(key, string)
36+
}, 5)
37+
rawData.registerTempTable("testData")
3638

3739
// Enable in-memory partition pruning
3840
setConf(SQLConf.IN_MEMORY_PARTITION_PRUNING, "true")
@@ -44,48 +46,64 @@ class PartitionBatchPruningSuite extends FunSuite with BeforeAndAfterAll with Be
4446
}
4547

4648
before {
47-
cacheTable("intData")
49+
cacheTable("testData")
4850
}
4951

5052
after {
51-
uncacheTable("intData")
53+
uncacheTable("testData")
5254
}
5355

5456
// Comparisons
55-
checkBatchPruning("i = 1", Seq(1), 1, 1)
56-
checkBatchPruning("1 = i", Seq(1), 1, 1)
57-
checkBatchPruning("i < 12", 1 to 11, 1, 2)
58-
checkBatchPruning("i <= 11", 1 to 11, 1, 2)
59-
checkBatchPruning("i > 88", 89 to 100, 1, 2)
60-
checkBatchPruning("i >= 89", 89 to 100, 1, 2)
61-
checkBatchPruning("12 > i", 1 to 11, 1, 2)
62-
checkBatchPruning("11 >= i", 1 to 11, 1, 2)
63-
checkBatchPruning("88 < i", 89 to 100, 1, 2)
64-
checkBatchPruning("89 <= i", 89 to 100, 1, 2)
57+
checkBatchPruning("SELECT key FROM testData WHERE key = 1", 1, 1)(Seq(1))
58+
checkBatchPruning("SELECT key FROM testData WHERE 1 = key", 1, 1)(Seq(1))
59+
checkBatchPruning("SELECT key FROM testData WHERE key < 12", 1, 2)(1 to 11)
60+
checkBatchPruning("SELECT key FROM testData WHERE key <= 11", 1, 2)(1 to 11)
61+
checkBatchPruning("SELECT key FROM testData WHERE key > 88", 1, 2)(89 to 100)
62+
checkBatchPruning("SELECT key FROM testData WHERE key >= 89", 1, 2)(89 to 100)
63+
checkBatchPruning("SELECT key FROM testData WHERE 12 > key", 1, 2)(1 to 11)
64+
checkBatchPruning("SELECT key FROM testData WHERE 11 >= key", 1, 2)(1 to 11)
65+
checkBatchPruning("SELECT key FROM testData WHERE 88 < key", 1, 2)(89 to 100)
66+
checkBatchPruning("SELECT key FROM testData WHERE 89 <= key", 1, 2)(89 to 100)
67+
68+
// IS NULL
69+
checkBatchPruning("SELECT key FROM testData WHERE value IS NULL", 5, 5) {
70+
(1 to 10) ++ (21 to 30) ++ (41 to 50) ++ (61 to 70) ++ (81 to 90)
71+
}
72+
73+
// IS NOT NULL
74+
checkBatchPruning("SELECT key FROM testData WHERE value IS NOT NULL", 5, 5) {
75+
(11 to 20) ++ (31 to 40) ++ (51 to 60) ++ (71 to 80) ++ (91 to 100)
76+
}
6577

6678
// Conjunction and disjunction
67-
checkBatchPruning("i > 8 AND i <= 21", 9 to 21, 2, 3)
68-
checkBatchPruning("i < 2 OR i > 99", Seq(1, 100), 2, 2)
69-
checkBatchPruning("i < 2 OR (i > 78 AND i < 92)", Seq(1) ++ (79 to 91), 3, 4)
70-
checkBatchPruning("NOT (i < 88)", 88 to 100, 1, 2)
79+
checkBatchPruning("SELECT key FROM testData WHERE key > 8 AND key <= 21", 2, 3)(9 to 21)
80+
checkBatchPruning("SELECT key FROM testData WHERE key < 2 OR key > 99", 2, 2)(Seq(1, 100))
81+
checkBatchPruning("SELECT key FROM testData WHERE key < 2 OR (key > 78 AND key < 92)", 3, 4) {
82+
Seq(1) ++ (79 to 91)
83+
}
7184

7285
// With unsupported predicate
73-
checkBatchPruning("i < 12 AND i IS NOT NULL", 1 to 11, 1, 2)
74-
checkBatchPruning(s"NOT (i in (${(1 to 30).mkString(",")}))", 31 to 100, 5, 10)
86+
checkBatchPruning("SELECT key FROM testData WHERE NOT (key < 88)", 1, 2)(88 to 100)
87+
checkBatchPruning("SELECT key FROM testData WHERE key < 12 AND key IS NOT NULL", 1, 2)(1 to 11)
88+
89+
{
90+
val seq = (1 to 30).mkString(", ")
91+
checkBatchPruning(s"SELECT key FROM testData WHERE NOT (key IN ($seq))", 5, 10)(31 to 100)
92+
}
7593

7694
def checkBatchPruning(
77-
filter: String,
78-
expectedQueryResult: Seq[Int],
95+
query: String,
7996
expectedReadPartitions: Int,
80-
expectedReadBatches: Int): Unit = {
97+
expectedReadBatches: Int)(
98+
expectedQueryResult: => Seq[Int]): Unit = {
8199

82-
test(filter) {
83-
val query = sql(s"SELECT * FROM intData WHERE $filter")
100+
test(query) {
101+
val schemaRdd = sql(query)
84102
assertResult(expectedQueryResult.toArray, "Wrong query result") {
85-
query.collect().map(_.head).toArray
103+
schemaRdd.collect().map(_.head).toArray
86104
}
87105

88-
val (readPartitions, readBatches) = query.queryExecution.executedPlan.collect {
106+
val (readPartitions, readBatches) = schemaRdd.queryExecution.executedPlan.collect {
89107
case in: InMemoryColumnarTableScan => (in.readPartitions.value, in.readBatches.value)
90108
}.head
91109

0 commit comments

Comments
 (0)