
Commit 19c7c7e

cloud-fan authored and gatorsmile committed
[SPARK-23301][SQL] data source column pruning should work for arbitrary expressions
## What changes were proposed in this pull request?

This PR fixes a mistake in the `PushDownOperatorsToDataSource` rule: the column pruning logic was incorrect for `Project`.

## How was this patch tested?

A new test case for column pruning with arbitrary expressions, plus improvements to the existing tests to make sure `PushDownOperatorsToDataSource` really works.

Author: Wenchen Fan <wenchen@databricks.com>

Closes #20476 from cloud-fan/push-down.
1 parent b3a0428 commit 19c7c7e
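
For context, the following toy sketch (plain Scala with simplified stand-ins for Catalyst's attribute and expression classes; illustration only, not code from this commit) shows the mistake: an aliased expression such as `'i + 1` is never equal to the output attribute the parent requires, so the old `projectList.filter(requiredByParent.contains)` dropped it and the input column `i` was pruned away, while the new logic simply collects every attribute the project list references.

// Toy model only: Attr/Add/Alias stand in for Catalyst's AttributeReference, Add and Alias.
object ColumnPruningSketch {
  sealed trait Expr { def references: Set[String] }
  case class Attr(name: String) extends Expr { def references: Set[String] = Set(name) }
  case class Add(child: Expr, delta: Int) extends Expr { def references: Set[String] = child.references }
  case class Alias(child: Expr, name: String) extends Expr { def references: Set[String] = child.references }

  def main(args: Array[String]): Unit = {
    // SELECT i + 1 AS x FROM source(i, j): the parent operator requires only "x".
    val projectList: Seq[Expr] = Seq(Alias(Add(Attr("i"), 1), "x"))
    val requiredByParent = Set("x")

    // Old logic: keep only project entries that are themselves the required attributes.
    // Alias(i + 1, "x") is not the attribute "x", so nothing survives and "i" is lost.
    val oldRequired = projectList.filter {
      case Attr(n) => requiredByParent(n)
      case _ => false
    }.flatMap(_.references)

    // New logic: every column referenced anywhere in the project list is required.
    val newRequired = projectList.flatMap(_.references)

    println(s"old required columns: $oldRequired")  // List()
    println(s"new required columns: $newRequired")  // List(i)
  }
}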

File tree: 3 files changed, +155 -40 lines changed


sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownOperatorsToDataSource.scala

Lines changed: 26 additions & 27 deletions
@@ -17,7 +17,7 @@

 package org.apache.spark.sql.execution.datasources.v2

-import org.apache.spark.sql.catalyst.expressions.{And, Attribute, AttributeMap, Expression, NamedExpression, PredicateHelper}
+import org.apache.spark.sql.catalyst.expressions.{And, Attribute, AttributeMap, AttributeSet, Expression, NamedExpression, PredicateHelper}
 import org.apache.spark.sql.catalyst.optimizer.RemoveRedundantProject
 import org.apache.spark.sql.catalyst.plans.logical.{Filter, LogicalPlan, Project}
 import org.apache.spark.sql.catalyst.rules.Rule
@@ -81,35 +81,34 @@ object PushDownOperatorsToDataSource extends Rule[LogicalPlan] with PredicateHel

     // TODO: add more push down rules.

-    // TODO: nested fields pruning
-    def pushDownRequiredColumns(plan: LogicalPlan, requiredByParent: Seq[Attribute]): Unit = {
-      plan match {
-        case Project(projectList, child) =>
-          val required = projectList.filter(requiredByParent.contains).flatMap(_.references)
-          pushDownRequiredColumns(child, required)
-
-        case Filter(condition, child) =>
-          val required = requiredByParent ++ condition.references
-          pushDownRequiredColumns(child, required)
-
-        case DataSourceV2Relation(fullOutput, reader) => reader match {
-          case r: SupportsPushDownRequiredColumns =>
-            // Match original case of attributes.
-            val attrMap = AttributeMap(fullOutput.zip(fullOutput))
-            val requiredColumns = requiredByParent.map(attrMap)
-            r.pruneColumns(requiredColumns.toStructType)
-          case _ =>
-        }
+    pushDownRequiredColumns(filterPushed, filterPushed.outputSet)
+    // After column pruning, we may have redundant PROJECT nodes in the query plan, remove them.
+    RemoveRedundantProject(filterPushed)
+  }
+
+  // TODO: nested fields pruning
+  private def pushDownRequiredColumns(plan: LogicalPlan, requiredByParent: AttributeSet): Unit = {
+    plan match {
+      case Project(projectList, child) =>
+        val required = projectList.flatMap(_.references)
+        pushDownRequiredColumns(child, AttributeSet(required))
+
+      case Filter(condition, child) =>
+        val required = requiredByParent ++ condition.references
+        pushDownRequiredColumns(child, required)

-        // TODO: there may be more operators can be used to calculate required columns, we can add
-        // more and more in the future.
-        case _ => plan.children.foreach(child => pushDownRequiredColumns(child, child.output))
+      case relation: DataSourceV2Relation => relation.reader match {
+        case reader: SupportsPushDownRequiredColumns =>
+          val requiredColumns = relation.output.filter(requiredByParent.contains)
+          reader.pruneColumns(requiredColumns.toStructType)
+
+        case _ =>
       }
-    }

-    pushDownRequiredColumns(filterPushed, filterPushed.output)
-    // After column pruning, we may have redundant PROJECT nodes in the query plan, remove them.
-    RemoveRedundantProject(filterPushed)
+      // TODO: there may be more operators that can be used to calculate the required columns. We
+      // can add more and more in the future.
+      case _ => plan.children.foreach(child => pushDownRequiredColumns(child, child.outputSet))
+    }
   }

   /**
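
To observe the rewritten rule outside the test suite, here is a spark-shell style sketch (assuming a Spark 2.3 build with the test-only AdvancedDataSourceV2 class from DataSourceV2Suite on the classpath; the inspection mirrors the getReader helper added to the tests below):

import org.apache.spark.sql.execution.datasources.v2.DataSourceV2ScanExec
import org.apache.spark.sql.sources.v2.AdvancedDataSourceV2
import spark.implicits._

val df = spark.read.format(classOf[AdvancedDataSourceV2].getName).load()

// An arbitrary expression on top of the scan: only column "i" should reach the reader.
val q = df.select($"i" + 1)

// Pull the data source reader out of the physical plan, as the new tests do.
val reader = q.queryExecution.executedPlan.collect {
  case scan: DataSourceV2ScanExec => scan.reader.asInstanceOf[AdvancedDataSourceV2#Reader]
}.head

println(reader.requiredSchema.fieldNames.mkString(", "))  // expected: i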

sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaAdvancedDataSourceV2.java

Lines changed: 24 additions & 5 deletions
@@ -32,11 +32,12 @@

 public class JavaAdvancedDataSourceV2 implements DataSourceV2, ReadSupport {

-  class Reader implements DataSourceReader, SupportsPushDownRequiredColumns,
+  public class Reader implements DataSourceReader, SupportsPushDownRequiredColumns,
     SupportsPushDownFilters {

-    private StructType requiredSchema = new StructType().add("i", "int").add("j", "int");
-    private Filter[] filters = new Filter[0];
+    // Exposed for testing.
+    public StructType requiredSchema = new StructType().add("i", "int").add("j", "int");
+    public Filter[] filters = new Filter[0];

     @Override
     public StructType readSchema() {
@@ -50,8 +51,26 @@ public void pruneColumns(StructType requiredSchema) {

     @Override
     public Filter[] pushFilters(Filter[] filters) {
-      this.filters = filters;
-      return new Filter[0];
+      Filter[] supported = Arrays.stream(filters).filter(f -> {
+        if (f instanceof GreaterThan) {
+          GreaterThan gt = (GreaterThan) f;
+          return gt.attribute().equals("i") && gt.value() instanceof Integer;
+        } else {
+          return false;
+        }
+      }).toArray(Filter[]::new);
+
+      Filter[] unsupported = Arrays.stream(filters).filter(f -> {
+        if (f instanceof GreaterThan) {
+          GreaterThan gt = (GreaterThan) f;
+          return !gt.attribute().equals("i") || !(gt.value() instanceof Integer);
+        } else {
+          return true;
+        }
+      }).toArray(Filter[]::new);
+
+      this.filters = supported;
+      return unsupported;
     }

     @Override
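
For comparison, the same supported/unsupported split written with Scala's partition (a standalone sketch using Spark's public org.apache.spark.sql.sources filter classes; LessThan is just an example of a filter the reader hands back). Whatever pushFilters returns is re-evaluated by Spark on top of the scan.

import org.apache.spark.sql.sources.{Filter, GreaterThan, LessThan}

object FilterSplitSketch {
  // Keep GreaterThan on column "i" with an Int value; hand everything else back to Spark.
  def split(filters: Array[Filter]): (Array[Filter], Array[Filter]) =
    filters.partition {
      case GreaterThan("i", _: Int) => true
      case _ => false
    }

  def main(args: Array[String]): Unit = {
    val (supported, unsupported) = split(Array[Filter](GreaterThan("i", 3), LessThan("j", -10)))
    println(supported.mkString(", "))    // GreaterThan(i,3)   -> handled inside the source
    println(unsupported.mkString(", "))  // LessThan(j,-10)    -> Spark must still evaluate it
  }
}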

sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala

Lines changed: 105 additions & 8 deletions
@@ -21,11 +21,13 @@ import java.util.{ArrayList, List => JList}

 import test.org.apache.spark.sql.sources.v2._

-import org.apache.spark.{SparkConf, SparkException}
-import org.apache.spark.sql.{AnalysisException, QueryTest, Row}
+import org.apache.spark.SparkException
+import org.apache.spark.sql.{AnalysisException, DataFrame, QueryTest, Row}
 import org.apache.spark.sql.catalyst.expressions.UnsafeRow
+import org.apache.spark.sql.execution.datasources.v2.DataSourceV2ScanExec
 import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec
 import org.apache.spark.sql.execution.vectorized.OnHeapColumnVector
+import org.apache.spark.sql.functions._
 import org.apache.spark.sql.sources.{Filter, GreaterThan}
 import org.apache.spark.sql.sources.v2.reader._
 import org.apache.spark.sql.sources.v2.reader.partitioning.{ClusteredDistribution, Distribution, Partitioning}
@@ -48,14 +50,72 @@ class DataSourceV2Suite extends QueryTest with SharedSQLContext {
   }

   test("advanced implementation") {
+    def getReader(query: DataFrame): AdvancedDataSourceV2#Reader = {
+      query.queryExecution.executedPlan.collect {
+        case d: DataSourceV2ScanExec => d.reader.asInstanceOf[AdvancedDataSourceV2#Reader]
+      }.head
+    }
+
+    def getJavaReader(query: DataFrame): JavaAdvancedDataSourceV2#Reader = {
+      query.queryExecution.executedPlan.collect {
+        case d: DataSourceV2ScanExec => d.reader.asInstanceOf[JavaAdvancedDataSourceV2#Reader]
+      }.head
+    }
+
     Seq(classOf[AdvancedDataSourceV2], classOf[JavaAdvancedDataSourceV2]).foreach { cls =>
       withClue(cls.getName) {
         val df = spark.read.format(cls.getName).load()
         checkAnswer(df, (0 until 10).map(i => Row(i, -i)))
-        checkAnswer(df.select('j), (0 until 10).map(i => Row(-i)))
-        checkAnswer(df.filter('i > 3), (4 until 10).map(i => Row(i, -i)))
-        checkAnswer(df.select('j).filter('i > 6), (7 until 10).map(i => Row(-i)))
-        checkAnswer(df.select('i).filter('i > 10), Nil)
+
+        val q1 = df.select('j)
+        checkAnswer(q1, (0 until 10).map(i => Row(-i)))
+        if (cls == classOf[AdvancedDataSourceV2]) {
+          val reader = getReader(q1)
+          assert(reader.filters.isEmpty)
+          assert(reader.requiredSchema.fieldNames === Seq("j"))
+        } else {
+          val reader = getJavaReader(q1)
+          assert(reader.filters.isEmpty)
+          assert(reader.requiredSchema.fieldNames === Seq("j"))
+        }
+
+        val q2 = df.filter('i > 3)
+        checkAnswer(q2, (4 until 10).map(i => Row(i, -i)))
+        if (cls == classOf[AdvancedDataSourceV2]) {
+          val reader = getReader(q2)
+          assert(reader.filters.flatMap(_.references).toSet == Set("i"))
+          assert(reader.requiredSchema.fieldNames === Seq("i", "j"))
+        } else {
+          val reader = getJavaReader(q2)
+          assert(reader.filters.flatMap(_.references).toSet == Set("i"))
+          assert(reader.requiredSchema.fieldNames === Seq("i", "j"))
+        }
+
+        val q3 = df.select('i).filter('i > 6)
+        checkAnswer(q3, (7 until 10).map(i => Row(i)))
+        if (cls == classOf[AdvancedDataSourceV2]) {
+          val reader = getReader(q3)
+          assert(reader.filters.flatMap(_.references).toSet == Set("i"))
+          assert(reader.requiredSchema.fieldNames === Seq("i"))
+        } else {
+          val reader = getJavaReader(q3)
+          assert(reader.filters.flatMap(_.references).toSet == Set("i"))
+          assert(reader.requiredSchema.fieldNames === Seq("i"))
+        }
+
+        val q4 = df.select('j).filter('j < -10)
+        checkAnswer(q4, Nil)
+        if (cls == classOf[AdvancedDataSourceV2]) {
+          val reader = getReader(q4)
+          // 'j < -10 is not supported by the testing data source.
+          assert(reader.filters.isEmpty)
+          assert(reader.requiredSchema.fieldNames === Seq("j"))
+        } else {
+          val reader = getJavaReader(q4)
+          // 'j < -10 is not supported by the testing data source.
+          assert(reader.filters.isEmpty)
+          assert(reader.requiredSchema.fieldNames === Seq("j"))
+        }
       }
     }
   }
@@ -223,6 +283,39 @@ class DataSourceV2Suite extends QueryTest with SharedSQLContext {
     val df2 = df.select(($"i" + 1).as("k"), $"j")
     checkAnswer(df.join(df2, "j"), (0 until 10).map(i => Row(-i, i, i + 1)))
   }
+
+  test("SPARK-23301: column pruning with arbitrary expressions") {
+    def getReader(query: DataFrame): AdvancedDataSourceV2#Reader = {
+      query.queryExecution.executedPlan.collect {
+        case d: DataSourceV2ScanExec => d.reader.asInstanceOf[AdvancedDataSourceV2#Reader]
+      }.head
+    }
+
+    val df = spark.read.format(classOf[AdvancedDataSourceV2].getName).load()
+
+    val q1 = df.select('i + 1)
+    checkAnswer(q1, (1 until 11).map(i => Row(i)))
+    val reader1 = getReader(q1)
+    assert(reader1.requiredSchema.fieldNames === Seq("i"))
+
+    val q2 = df.select(lit(1))
+    checkAnswer(q2, (0 until 10).map(i => Row(1)))
+    val reader2 = getReader(q2)
+    assert(reader2.requiredSchema.isEmpty)
+
+    // 'j === -1 can't be pushed down, but we should still be able to do column pruning.
+    val q3 = df.filter('j === -1).select('j * 2)
+    checkAnswer(q3, Row(-2))
+    val reader3 = getReader(q3)
+    assert(reader3.filters.isEmpty)
+    assert(reader3.requiredSchema.fieldNames === Seq("j"))
+
+    // Column pruning should work with other operators.
+    val q4 = df.sort('i).limit(1).select('i + 1)
+    checkAnswer(q4, Row(1))
+    val reader4 = getReader(q4)
+    assert(reader4.requiredSchema.fieldNames === Seq("i"))
+  }
 }

 class SimpleDataSourceV2 extends DataSourceV2 with ReadSupport {
@@ -270,8 +363,12 @@ class AdvancedDataSourceV2 extends DataSourceV2 with ReadSupport {
     }

     override def pushFilters(filters: Array[Filter]): Array[Filter] = {
-      this.filters = filters
-      Array.empty
+      val (supported, unsupported) = filters.partition {
+        case GreaterThan("i", _: Int) => true
+        case _ => false
+      }
+      this.filters = supported
+      unsupported
     }

     override def pushedFilters(): Array[Filter] = filters
