@@ -21,11 +21,13 @@ import java.util.{ArrayList, List => JList}
 
 import test.org.apache.spark.sql.sources.v2._
 
-import org.apache.spark.{SparkConf, SparkException}
-import org.apache.spark.sql.{AnalysisException, QueryTest, Row}
+import org.apache.spark.SparkException
+import org.apache.spark.sql.{AnalysisException, DataFrame, QueryTest, Row}
 import org.apache.spark.sql.catalyst.expressions.UnsafeRow
+import org.apache.spark.sql.execution.datasources.v2.DataSourceV2ScanExec
 import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec
 import org.apache.spark.sql.execution.vectorized.OnHeapColumnVector
+import org.apache.spark.sql.functions._
 import org.apache.spark.sql.sources.{Filter, GreaterThan}
 import org.apache.spark.sql.sources.v2.reader._
 import org.apache.spark.sql.sources.v2.reader.partitioning.{ClusteredDistribution, Distribution, Partitioning}
@@ -48,14 +50,72 @@ class DataSourceV2Suite extends QueryTest with SharedSQLContext {
   }
 
   test("advanced implementation") {
+    def getReader(query: DataFrame): AdvancedDataSourceV2#Reader = {
+      query.queryExecution.executedPlan.collect {
+        case d: DataSourceV2ScanExec => d.reader.asInstanceOf[AdvancedDataSourceV2#Reader]
+      }.head
+    }
+
+    def getJavaReader(query: DataFrame): JavaAdvancedDataSourceV2#Reader = {
+      query.queryExecution.executedPlan.collect {
+        case d: DataSourceV2ScanExec => d.reader.asInstanceOf[JavaAdvancedDataSourceV2#Reader]
+      }.head
+    }
+
     Seq(classOf[AdvancedDataSourceV2], classOf[JavaAdvancedDataSourceV2]).foreach { cls =>
       withClue(cls.getName) {
         val df = spark.read.format(cls.getName).load()
         checkAnswer(df, (0 until 10).map(i => Row(i, -i)))
-        checkAnswer(df.select('j), (0 until 10).map(i => Row(-i)))
-        checkAnswer(df.filter('i > 3), (4 until 10).map(i => Row(i, -i)))
-        checkAnswer(df.select('j).filter('i > 6), (7 until 10).map(i => Row(-i)))
-        checkAnswer(df.select('i).filter('i > 10), Nil)
+
+        val q1 = df.select('j)
+        checkAnswer(q1, (0 until 10).map(i => Row(-i)))
+        if (cls == classOf[AdvancedDataSourceV2]) {
+          val reader = getReader(q1)
+          assert(reader.filters.isEmpty)
+          assert(reader.requiredSchema.fieldNames === Seq("j"))
+        } else {
+          val reader = getJavaReader(q1)
+          assert(reader.filters.isEmpty)
+          assert(reader.requiredSchema.fieldNames === Seq("j"))
+        }
+
+        val q2 = df.filter('i > 3)
+        checkAnswer(q2, (4 until 10).map(i => Row(i, -i)))
+        if (cls == classOf[AdvancedDataSourceV2]) {
+          val reader = getReader(q2)
+          assert(reader.filters.flatMap(_.references).toSet == Set("i"))
+          assert(reader.requiredSchema.fieldNames === Seq("i", "j"))
+        } else {
+          val reader = getJavaReader(q2)
+          assert(reader.filters.flatMap(_.references).toSet == Set("i"))
+          assert(reader.requiredSchema.fieldNames === Seq("i", "j"))
+        }
+
+        val q3 = df.select('i).filter('i > 6)
+        checkAnswer(q3, (7 until 10).map(i => Row(i)))
+        if (cls == classOf[AdvancedDataSourceV2]) {
+          val reader = getReader(q3)
+          assert(reader.filters.flatMap(_.references).toSet == Set("i"))
+          assert(reader.requiredSchema.fieldNames === Seq("i"))
+        } else {
+          val reader = getJavaReader(q3)
+          assert(reader.filters.flatMap(_.references).toSet == Set("i"))
+          assert(reader.requiredSchema.fieldNames === Seq("i"))
+        }
+
+        val q4 = df.select('j).filter('j < -10)
+        checkAnswer(q4, Nil)
+        if (cls == classOf[AdvancedDataSourceV2]) {
+          val reader = getReader(q4)
+          // 'j < -10 is not supported by the testing data source.
+          assert(reader.filters.isEmpty)
+          assert(reader.requiredSchema.fieldNames === Seq("j"))
+        } else {
+          val reader = getJavaReader(q4)
+          // 'j < -10 is not supported by the testing data source.
+          assert(reader.filters.isEmpty)
+          assert(reader.requiredSchema.fieldNames === Seq("j"))
+        }
       }
     }
   }
@@ -223,6 +283,39 @@ class DataSourceV2Suite extends QueryTest with SharedSQLContext {
     val df2 = df.select(($"i" + 1).as("k"), $"j")
     checkAnswer(df.join(df2, "j"), (0 until 10).map(i => Row(-i, i, i + 1)))
   }
+
+  test("SPARK-23301: column pruning with arbitrary expressions") {
+    def getReader(query: DataFrame): AdvancedDataSourceV2#Reader = {
+      query.queryExecution.executedPlan.collect {
+        case d: DataSourceV2ScanExec => d.reader.asInstanceOf[AdvancedDataSourceV2#Reader]
+      }.head
+    }
+
+    val df = spark.read.format(classOf[AdvancedDataSourceV2].getName).load()
+
+    val q1 = df.select('i + 1)
+    checkAnswer(q1, (1 until 11).map(i => Row(i)))
+    val reader1 = getReader(q1)
+    assert(reader1.requiredSchema.fieldNames === Seq("i"))
+
+    val q2 = df.select(lit(1))
+    checkAnswer(q2, (0 until 10).map(i => Row(1)))
+    val reader2 = getReader(q2)
+    assert(reader2.requiredSchema.isEmpty)
+
+    // 'j === -1 can't be pushed down, but we should still be able to do column pruning.
+    val q3 = df.filter('j === -1).select('j * 2)
+    checkAnswer(q3, Row(-2))
+    val reader3 = getReader(q3)
+    assert(reader3.filters.isEmpty)
+    assert(reader3.requiredSchema.fieldNames === Seq("j"))
+
+    // Column pruning should also work together with other operators.
+    val q4 = df.sort('i).limit(1).select('i + 1)
+    checkAnswer(q4, Row(1))
+    val reader4 = getReader(q4)
+    assert(reader4.requiredSchema.fieldNames === Seq("i"))
+  }
 }
 
 class SimpleDataSourceV2 extends DataSourceV2 with ReadSupport {
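The new SPARK-23301 test asserts that the schema handed to the reader is pruned even when the projected expressions are not bare column references (for example `'i + 1` or `lit(1)`). A rough, self-contained sketch of the handshake being asserted follows; the `ColumnPruningDemo` object and its `pruneColumns` stand-in are hypothetical and are not the DataSourceV2 reader interface itself, only the `StructType` calls are real Spark API.

```scala
import org.apache.spark.sql.types.StructType

object ColumnPruningDemo extends App {
  // The testing source exposes two int columns, i and j.
  val fullSchema = new StructType().add("i", "int").add("j", "int")

  // Stand-in for the reader's pruning hook: remember the narrowest schema
  // the planner says the query needs and expose it as the read schema.
  var readSchema: StructType = fullSchema
  def pruneColumns(requiredSchema: StructType): Unit = { readSchema = requiredSchema }

  // For a projection like df.select('i + 1) only column i is needed,
  // even though the projected expression is not a bare column reference.
  pruneColumns(new StructType().add("i", "int"))

  println(readSchema.fieldNames.mkString(", ")) // i
}
```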
@@ -270,8 +363,12 @@ class AdvancedDataSourceV2 extends DataSourceV2 with ReadSupport {
     }
 
     override def pushFilters(filters: Array[Filter]): Array[Filter] = {
-      this.filters = filters
-      Array.empty
+      val (supported, unsupported) = filters.partition {
+        case GreaterThan("i", _: Int) => true
+        case _ => false
+      }
+      this.filters = supported
+      unsupported
     }
 
     override def pushedFilters(): Array[Filter] = filters
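The reworked `pushFilters` is what lets the suite assert on both sides of the push-down: the testing source now keeps only integer `'i > x` predicates and returns everything else as unsupported, so Spark knows which filters it still has to evaluate on top of the scan. A minimal, self-contained sketch of that contract is below; the `PushDownDemo` object is hypothetical and only reuses the public `Filter` case classes, not the DataSourceV2 reader mixins.

```scala
import org.apache.spark.sql.sources.{EqualTo, Filter, GreaterThan, LessThan}

object PushDownDemo extends App {
  // Same partitioning rule as the reworked pushFilters above: keep only
  // integer `'i > x` predicates, hand everything else back to Spark.
  def pushFilters(filters: Array[Filter]): (Array[Filter], Array[Filter]) =
    filters.partition {
      case GreaterThan("i", _: Int) => true
      case _ => false
    }

  val (pushed, residual) =
    pushFilters(Array(GreaterThan("i", 3), LessThan("j", -10), EqualTo("j", -1)))

  println(pushed.mkString(", "))   // GreaterThan(i,3)
  println(residual.mkString(", ")) // LessThan(j,-10), EqualTo(j,-1)
}
```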