feat: add scalar subquery pushdown to scan (#678)
## Which issue does this PR close?
Part of #372 and #551

## Rationale for this change
With Spark 4.0, Spark's `SubquerySuite` fails because Comet scan did not support scalar subquery pushdown.

## What changes are included in this PR?
Adds support for scalar subquery pushdown to Comet scan.

## How are these changes tested?
Existing Spark SQL unit tests in `SubquerySuite`.
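
For illustration, the kind of query this affects is a scan whose data filter compares a column to a scalar subquery result. A minimal sketch, assuming a `SparkSession` in scope as `spark`; the table names and data below are made up, not taken from this PR or from `SubquerySuite`:

```scala
// Illustrative only: tables and values are hypothetical.
// The WHERE clause contains a scalar subquery; once its result is known it can be
// substituted as a literal, and the resulting filter pushed down into the Parquet scan.
spark.sql("CREATE TABLE t1 USING parquet AS SELECT id AS a, id AS b FROM range(10)")
spark.sql("CREATE TABLE t2 USING parquet AS SELECT id AS c FROM range(5)")

spark.sql("SELECT a FROM t1 WHERE b > (SELECT max(c) FROM t2)").show()
```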
parthchandra authored Jul 19, 2024
1 parent e8765d4 commit 5806b82
Showing 5 changed files with 79 additions and 19 deletions.
41 changes: 33 additions & 8 deletions dev/diffs/4.0.0-preview1.diff
@@ -900,7 +900,7 @@ index 56c364e2084..11779ee3b4b 100644
withTable("dt") {
sql("create table dt using parquet as select 9000000000BD as d")
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala
index 68f14f13bbd..4b8e967102f 100644
index 68f14f13bbd..174636cefb5 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala
@@ -22,10 +22,11 @@ import scala.collection.mutable.ArrayBuffer
@@ -938,16 +938,41 @@ index 68f14f13bbd..4b8e967102f 100644
}
assert(exchanges.size === 1)
}
@@ -2668,7 +2675,8 @@ class SubquerySuite extends QueryTest
}
}

- test("SPARK-43402: FileSourceScanExec supports push down data filter with scalar subquery") {
+ test("SPARK-43402: FileSourceScanExec supports push down data filter with scalar subquery",
+ IgnoreComet("TODO: https://github.com/apache/datafusion-comet/issues/551")) {
@@ -2672,18 +2679,26 @@ class SubquerySuite extends QueryTest
def checkFileSourceScan(query: String, answer: Seq[Row]): Unit = {
val df = sql(query)
checkAnswer(df, answer)
- val fileSourceScanExec = collect(df.queryExecution.executedPlan) {
- case f: FileSourceScanExec => f
+ val dataSourceScanExec = collect(df.queryExecution.executedPlan) {
+ case f: FileSourceScanLike => f
+ case c: CometScanExec => c
}
sparkContext.listenerBus.waitUntilEmpty()
- assert(fileSourceScanExec.size === 1)
- val scalarSubquery = fileSourceScanExec.head.dataFilters.flatMap(_.collect {
- case s: ScalarSubquery => s
- })
+ assert(dataSourceScanExec.size === 1)
+ val scalarSubquery = dataSourceScanExec.head match {
+ case f: FileSourceScanLike =>
+ f.dataFilters.flatMap(_.collect {
+ case s: ScalarSubquery => s
+ })
+ case c: CometScanExec =>
+ c.dataFilters.flatMap(_.collect {
+ case s: ScalarSubquery => s
+ })
+ }
assert(scalarSubquery.length === 1)
assert(scalarSubquery.head.plan.isInstanceOf[ReusedSubqueryExec])
- assert(fileSourceScanExec.head.metrics("numFiles").value === 1)
- assert(fileSourceScanExec.head.metrics("numOutputRows").value === answer.size)
+ assert(dataSourceScanExec.head.metrics("numFiles").value === 1)
+ assert(dataSourceScanExec.head.metrics("numOutputRows").value === answer.size)
}

withTable("t1", "t2") {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2Suite.scala
index 1de535df246..cc7ffc4eeb3 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2Suite.scala
@@ -135,10 +135,7 @@ case class CometScanExec(
(wrapped.outputPartitioning, wrapped.outputOrdering)

@transient
private lazy val pushedDownFilters = {
val supportNestedPredicatePushdown = DataSourceUtils.supportNestedPredicatePushdown(relation)
dataFilters.flatMap(DataSourceStrategy.translateFilter(_, supportNestedPredicatePushdown))
}
private lazy val pushedDownFilters = getPushedDownFilters(relation, dataFilters)

override lazy val metadata: Map[String, String] =
if (wrapped == null) Map.empty else wrapped.metadata
@@ -24,11 +24,12 @@ import org.apache.hadoop.fs.Path

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.AttributeReference
import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Expression}
import org.apache.spark.sql.errors.QueryExecutionErrors
import org.apache.spark.sql.execution.{FileSourceScanExec, PartitionedFileUtil}
import org.apache.spark.sql.execution.datasources._
import org.apache.spark.sql.execution.datasources.parquet.ParquetOptions
import org.apache.spark.sql.sources.Filter
import org.apache.spark.sql.types.StructType

trait ShimCometScanExec {
@@ -67,4 +68,9 @@ trait ShimCometScanExec {
maxSplitBytes: Long,
partitionValues: InternalRow): Seq[PartitionedFile] =
PartitionedFileUtil.splitFiles(sparkSession, file, isSplitable, maxSplitBytes, partitionValues)

protected def getPushedDownFilters(relation: HadoopFsRelation , dataFilters: Seq[Expression]): Seq[Filter] = {
val supportNestedPredicatePushdown = DataSourceUtils.supportNestedPredicatePushdown(relation)
dataFilters.flatMap(DataSourceStrategy.translateFilter(_, supportNestedPredicatePushdown))
}
}
@@ -23,11 +23,12 @@ import org.apache.hadoop.fs.Path

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.AttributeReference
import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Expression, FileSourceConstantMetadataAttribute, Literal}
import org.apache.spark.sql.errors.QueryExecutionErrors
import org.apache.spark.sql.execution.datasources.parquet.ParquetOptions
import org.apache.spark.sql.execution.datasources._
import org.apache.spark.sql.execution.{FileSourceScanExec, PartitionedFileUtil}
import org.apache.spark.sql.execution.{FileSourceScanExec, PartitionedFileUtil, ScalarSubquery}
import org.apache.spark.sql.sources.Filter
import org.apache.spark.sql.types.StructType

trait ShimCometScanExec {
@@ -68,4 +69,30 @@ trait ShimCometScanExec {
maxSplitBytes: Long,
partitionValues: InternalRow): Seq[PartitionedFile] =
PartitionedFileUtil.splitFiles(file, isSplitable, maxSplitBytes, partitionValues)

protected def getPushedDownFilters(relation: HadoopFsRelation , dataFilters: Seq[Expression]): Seq[Filter] = {
translateToV1Filters(relation, dataFilters, _.toLiteral)
}

// From Spark FileSourceScanLike
private def translateToV1Filters(relation: HadoopFsRelation,
dataFilters: Seq[Expression],
scalarSubqueryToLiteral: ScalarSubquery => Literal): Seq[Filter] = {
val scalarSubqueryReplaced = dataFilters.map(_.transform {
// Replace scalar subquery to literal so that `DataSourceStrategy.translateFilter` can
// support translating it.
case scalarSubquery: ScalarSubquery => scalarSubqueryToLiteral(scalarSubquery)
})

val supportNestedPredicatePushdown = DataSourceUtils.supportNestedPredicatePushdown(relation)
// `dataFilters` should not include any constant metadata col filters
// because the metadata struct has been flatted in FileSourceStrategy
// and thus metadata col filters are invalid to be pushed down. Metadata that is generated
// during the scan can be used for filters.
scalarSubqueryReplaced.filterNot(_.references.exists {
case FileSourceConstantMetadataAttribute(_) => true
case _ => false
}).flatMap(DataSourceStrategy.translateFilter(_, supportNestedPredicatePushdown))
}

}
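
For context, a minimal sketch (not part of this commit) of why the scalar-subquery-to-literal substitution above matters: `DataSourceStrategy.translateFilter` only understands predicates over attributes and literals, so a `ScalarSubquery` node has to be replaced by its result before translation. The attribute name `b` and the value `4L` below are illustrative assumptions.

```scala
import org.apache.spark.sql.catalyst.expressions.{AttributeReference, GreaterThan, Literal}
import org.apache.spark.sql.execution.datasources.DataSourceStrategy
import org.apache.spark.sql.sources
import org.apache.spark.sql.types.LongType

// Hypothetical data filter `b > <scalar subquery>` after the subquery has been
// replaced by its (assumed) result, Literal(4L), as translateToV1Filters does above.
val b = AttributeReference("b", LongType)()
val filter = GreaterThan(b, Literal(4L))

// With only attributes and literals left, the standard v1 translation succeeds,
// returning Some(sources.GreaterThan("b", 4L)), which the scan can push down.
val translated: Option[sources.Filter] =
  DataSourceStrategy.translateFilter(filter, /* supportNestedPredicatePushdown = */ true)
```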
@@ -20,16 +20,15 @@
package org.apache.spark.sql.comet.shims

import org.apache.comet.shims.ShimFileFormat

import org.apache.hadoop.fs.{FileStatus, Path}

import org.apache.spark.SparkException
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.AttributeReference
import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Expression}
import org.apache.spark.sql.execution.datasources._
import org.apache.spark.sql.execution.datasources.parquet.ParquetOptions
import org.apache.spark.sql.execution.{FileSourceScanExec, PartitionedFileUtil}
import org.apache.spark.sql.execution.datasources.{FilePartition, FileScanRDD, HadoopFsRelation, PartitionDirectory, PartitionedFile}
import org.apache.spark.sql.sources.Filter
import org.apache.spark.sql.types.StructType

trait ShimCometScanExec {
@@ -102,4 +101,10 @@ trait ShimCometScanExec {
maxSplitBytes: Long,
partitionValues: InternalRow): Seq[PartitionedFile] =
PartitionedFileUtil.splitFiles(sparkSession, file, filePath, isSplitable, maxSplitBytes, partitionValues)

protected def getPushedDownFilters(relation: HadoopFsRelation , dataFilters: Seq[Expression]): Seq[Filter] = {
val supportNestedPredicatePushdown = DataSourceUtils.supportNestedPredicatePushdown(relation)
dataFilters.flatMap(DataSourceStrategy.translateFilter(_, supportNestedPredicatePushdown))
}

}
