Skip to content

Commit f9739b9

Browse files
liancheng authored and marmbrus committed
[SPARK-4468][SQL] Backports #3334 to branch-1.1
<!-- Reviewable:start -->
[<img src="https://reviewable.io/review_button.png" height=40 alt="Review on Reviewable"/>](https://reviewable.io/reviews/apache/spark/3338)
<!-- Reviewable:end -->

Author: Cheng Lian <lian@databricks.com>

Closes #3338 from liancheng/spark-3334-for-1.1 and squashes the following commits:

bd17512 [Cheng Lian] Backports #3334 to branch-1.1
1 parent ae9b1f6 commit f9739b9

File tree

2 files changed

+75
-45
lines changed

2 files changed

+75
-45
lines changed

sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetFilters.scala

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -213,22 +213,27 @@ private[sql] object ParquetFilters {
213213
Some(createEqualityFilter(right.name, left, p))
214214
case p @ EqualTo(left: NamedExpression, right: Literal) if !left.nullable =>
215215
Some(createEqualityFilter(left.name, right, p))
216+
216217
case p @ LessThan(left: Literal, right: NamedExpression) if !right.nullable =>
217-
Some(createLessThanFilter(right.name, left, p))
218+
Some(createGreaterThanFilter(right.name, left, p))
218219
case p @ LessThan(left: NamedExpression, right: Literal) if !left.nullable =>
219220
Some(createLessThanFilter(left.name, right, p))
221+
220222
case p @ LessThanOrEqual(left: Literal, right: NamedExpression) if !right.nullable =>
221-
Some(createLessThanOrEqualFilter(right.name, left, p))
223+
Some(createGreaterThanOrEqualFilter(right.name, left, p))
222224
case p @ LessThanOrEqual(left: NamedExpression, right: Literal) if !left.nullable =>
223225
Some(createLessThanOrEqualFilter(left.name, right, p))
226+
224227
case p @ GreaterThan(left: Literal, right: NamedExpression) if !right.nullable =>
225-
Some(createGreaterThanFilter(right.name, left, p))
228+
Some(createLessThanFilter(right.name, left, p))
226229
case p @ GreaterThan(left: NamedExpression, right: Literal) if !left.nullable =>
227230
Some(createGreaterThanFilter(left.name, right, p))
231+
228232
case p @ GreaterThanOrEqual(left: Literal, right: NamedExpression) if !right.nullable =>
229-
Some(createGreaterThanOrEqualFilter(right.name, left, p))
233+
Some(createLessThanOrEqualFilter(right.name, left, p))
230234
case p @ GreaterThanOrEqual(left: NamedExpression, right: Literal) if !left.nullable =>
231235
Some(createGreaterThanOrEqualFilter(left.name, right, p))
236+
232237
case _ => None
233238
}
234239
}

sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala

Lines changed: 66 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -17,20 +17,19 @@
1717

1818
package org.apache.spark.sql.parquet
1919

20+
import org.apache.hadoop.fs.{FileSystem, Path}
21+
import org.apache.hadoop.mapreduce.Job
2022
import org.scalatest.{BeforeAndAfterAll, FunSuiteLike}
21-
2223
import parquet.hadoop.ParquetFileWriter
2324
import parquet.hadoop.util.ContextUtil
24-
import org.apache.hadoop.fs.{FileSystem, Path}
25-
import org.apache.hadoop.mapreduce.Job
2625

2726
import org.apache.spark.SparkContext
2827
import org.apache.spark.sql._
29-
import org.apache.spark.sql.catalyst.{SqlLexical, SqlParser}
3028
import org.apache.spark.sql.catalyst.analysis.{Star, UnresolvedAttribute}
3129
import org.apache.spark.sql.catalyst.expressions._
3230
import org.apache.spark.sql.catalyst.types.{BooleanType, IntegerType}
3331
import org.apache.spark.sql.catalyst.util.getTempFilePath
32+
import org.apache.spark.sql.catalyst.{SqlLexical, SqlParser}
3433
import org.apache.spark.sql.test.TestSQLContext
3534
import org.apache.spark.sql.test.TestSQLContext._
3635
import org.apache.spark.util.Utils
@@ -453,43 +452,46 @@ class ParquetQuerySuite extends QueryTest with FunSuiteLike with BeforeAndAfterA
453452
}
454453

455454
test("create RecordFilter for simple predicates") {
456-
val attribute1 = new AttributeReference("first", IntegerType, false)()
457-
val predicate1 = new EqualTo(attribute1, new Literal(1, IntegerType))
458-
val filter1 = ParquetFilters.createFilter(predicate1)
459-
assert(filter1.isDefined)
460-
assert(filter1.get.predicate == predicate1, "predicates do not match")
461-
assert(filter1.get.isInstanceOf[ComparisonFilter])
462-
val cmpFilter1 = filter1.get.asInstanceOf[ComparisonFilter]
463-
assert(cmpFilter1.columnName == "first", "column name incorrect")
464-
465-
val predicate2 = new LessThan(attribute1, new Literal(4, IntegerType))
466-
val filter2 = ParquetFilters.createFilter(predicate2)
467-
assert(filter2.isDefined)
468-
assert(filter2.get.predicate == predicate2, "predicates do not match")
469-
assert(filter2.get.isInstanceOf[ComparisonFilter])
470-
val cmpFilter2 = filter2.get.asInstanceOf[ComparisonFilter]
471-
assert(cmpFilter2.columnName == "first", "column name incorrect")
472-
473-
val predicate3 = new And(predicate1, predicate2)
474-
val filter3 = ParquetFilters.createFilter(predicate3)
475-
assert(filter3.isDefined)
476-
assert(filter3.get.predicate == predicate3, "predicates do not match")
477-
assert(filter3.get.isInstanceOf[AndFilter])
478-
479-
val predicate4 = new Or(predicate1, predicate2)
480-
val filter4 = ParquetFilters.createFilter(predicate4)
481-
assert(filter4.isDefined)
482-
assert(filter4.get.predicate == predicate4, "predicates do not match")
483-
assert(filter4.get.isInstanceOf[OrFilter])
484-
485-
val attribute2 = new AttributeReference("second", IntegerType, false)()
486-
val predicate5 = new GreaterThan(attribute1, attribute2)
487-
val badfilter = ParquetFilters.createFilter(predicate5)
488-
assert(badfilter.isDefined === false)
489-
490-
val predicate6 = And(GreaterThan(attribute1, attribute2), GreaterThan(attribute1, attribute2))
491-
val badfilter2 = ParquetFilters.createFilter(predicate6)
492-
assert(badfilter2.isDefined === false)
455+
def checkFilter(predicate: Predicate): Option[CatalystFilter] = {
456+
ParquetFilters.createFilter(predicate).map { f =>
457+
assertResult(predicate)(f.predicate)
458+
f
459+
}.orElse {
460+
fail(s"filter $predicate not pushed down")
461+
}
462+
}
463+
464+
def checkComparisonFilter(predicate: Predicate, columnName: String): Unit = {
465+
assertResult(columnName, "column name incorrect") {
466+
checkFilter(predicate).map(_.asInstanceOf[ComparisonFilter].columnName).get
467+
}
468+
}
469+
470+
def checkInvalidFilter(predicate: Predicate): Unit = {
471+
assert(ParquetFilters.createFilter(predicate).isEmpty)
472+
}
473+
474+
val a = 'a.int.notNull
475+
val b = 'b.int.notNull
476+
477+
checkComparisonFilter(a === 1, "a")
478+
checkComparisonFilter(Literal(1) === a, "a")
479+
480+
checkComparisonFilter(a < 4, "a")
481+
checkComparisonFilter(a > 4, "a")
482+
checkComparisonFilter(a <= 4, "a")
483+
checkComparisonFilter(a >= 4, "a")
484+
485+
checkComparisonFilter(Literal(4) > a, "a")
486+
checkComparisonFilter(Literal(4) < a, "a")
487+
checkComparisonFilter(Literal(4) >= a, "a")
488+
checkComparisonFilter(Literal(4) <= a, "a")
489+
490+
checkFilter(a === 1 && a < 4)
491+
checkFilter(a === 1 || a < 4)
492+
493+
checkInvalidFilter(a > b)
494+
checkInvalidFilter((a > b) && (a > b))
493495
}
494496

495497
test("test filter by predicate pushdown") {
@@ -516,6 +518,29 @@ class ParquetQuerySuite extends QueryTest with FunSuiteLike with BeforeAndAfterA
516518
assert(result2(49)(1) === 199)
517519
}
518520
}
521+
for(myval <- Seq("myint", "mylong", "mydouble", "myfloat")) {
522+
val query1 = sql(s"SELECT * FROM testfiltersource WHERE 150 > $myval AND 100 <= $myval")
523+
assert(
524+
query1.queryExecution.executedPlan(0)(0).isInstanceOf[ParquetTableScan],
525+
"Top operator should be ParquetTableScan after pushdown")
526+
val result1 = query1.collect()
527+
assert(result1.size === 50)
528+
assert(result1(0)(1) === 100)
529+
assert(result1(49)(1) === 149)
530+
val query2 = sql(s"SELECT * FROM testfiltersource WHERE 150 < $myval AND 200 >= $myval")
531+
assert(
532+
query2.queryExecution.executedPlan(0)(0).isInstanceOf[ParquetTableScan],
533+
"Top operator should be ParquetTableScan after pushdown")
534+
val result2 = query2.collect()
535+
assert(result2.size === 50)
536+
if (myval == "myint" || myval == "mylong") {
537+
assert(result2(0)(1) === 151)
538+
assert(result2(49)(1) === 200)
539+
} else {
540+
assert(result2(0)(1) === 150)
541+
assert(result2(49)(1) === 199)
542+
}
543+
}
519544
for(myval <- Seq("myint", "mylong")) {
520545
val query3 = sql(s"SELECT * FROM testfiltersource WHERE $myval > 190 OR $myval < 10")
521546
assert(

0 commit comments

Comments
 (0)