diff --git a/dev/diffs/4.0.0-preview1.diff b/dev/diffs/4.0.0-preview1.diff
index dfd57ce8f..e1abb0a35 100644
--- a/dev/diffs/4.0.0-preview1.diff
+++ b/dev/diffs/4.0.0-preview1.diff
@@ -900,7 +900,7 @@ index 56c364e2084..11779ee3b4b 100644
      withTable("dt") {
        sql("create table dt using parquet as select 9000000000BD as d")
 diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala
-index 68f14f13bbd..4b8e967102f 100644
+index 68f14f13bbd..174636cefb5 100644
 --- a/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala
 +++ b/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala
 @@ -22,10 +22,11 @@ import scala.collection.mutable.ArrayBuffer
@@ -938,16 +938,41 @@ index 68f14f13bbd..4b8e967102f 100644
        }
        assert(exchanges.size === 1)
      }
-@@ -2668,7 +2675,8 @@ class SubquerySuite extends QueryTest
-     }
-   }
-
--  test("SPARK-43402: FileSourceScanExec supports push down data filter with scalar subquery") {
-+  test("SPARK-43402: FileSourceScanExec supports push down data filter with scalar subquery",
-+    IgnoreComet("TODO: https://github.com/apache/datafusion-comet/issues/551")) {
+@@ -2672,18 +2679,26 @@ class SubquerySuite extends QueryTest
      def checkFileSourceScan(query: String, answer: Seq[Row]): Unit = {
        val df = sql(query)
        checkAnswer(df, answer)
+-      val fileSourceScanExec = collect(df.queryExecution.executedPlan) {
+-        case f: FileSourceScanExec => f
++      val dataSourceScanExec = collect(df.queryExecution.executedPlan) {
++        case f: FileSourceScanLike => f
++        case c: CometScanExec => c
+       }
+       sparkContext.listenerBus.waitUntilEmpty()
+-      assert(fileSourceScanExec.size === 1)
+-      val scalarSubquery = fileSourceScanExec.head.dataFilters.flatMap(_.collect {
+-        case s: ScalarSubquery => s
+-      })
++      assert(dataSourceScanExec.size === 1)
++      val scalarSubquery = dataSourceScanExec.head match {
++        case f: FileSourceScanLike =>
++          f.dataFilters.flatMap(_.collect {
++            case s: ScalarSubquery => s
++          })
++        case c: CometScanExec =>
++          c.dataFilters.flatMap(_.collect {
++            case s: ScalarSubquery => s
++          })
++      }
+       assert(scalarSubquery.length === 1)
+       assert(scalarSubquery.head.plan.isInstanceOf[ReusedSubqueryExec])
+-      assert(fileSourceScanExec.head.metrics("numFiles").value === 1)
+-      assert(fileSourceScanExec.head.metrics("numOutputRows").value === answer.size)
++      assert(dataSourceScanExec.head.metrics("numFiles").value === 1)
++      assert(dataSourceScanExec.head.metrics("numOutputRows").value === answer.size)
+     }
+
+     withTable("t1", "t2") {
 diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2Suite.scala
 index 1de535df246..cc7ffc4eeb3 100644
 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2Suite.scala
 +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2Suite.scala
diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/CometScanExec.scala b/spark/src/main/scala/org/apache/spark/sql/comet/CometScanExec.scala
index ca99e36b8..adaefb238 100644
--- a/spark/src/main/scala/org/apache/spark/sql/comet/CometScanExec.scala
+++ b/spark/src/main/scala/org/apache/spark/sql/comet/CometScanExec.scala
@@ -135,10 +135,7 @@ case class CometScanExec(
     (wrapped.outputPartitioning, wrapped.outputOrdering)

   @transient
-  private lazy val pushedDownFilters = {
-    val supportNestedPredicatePushdown = DataSourceUtils.supportNestedPredicatePushdown(relation)
-    dataFilters.flatMap(DataSourceStrategy.translateFilter(_, supportNestedPredicatePushdown))
-  }
+  private lazy val pushedDownFilters = getPushedDownFilters(relation, dataFilters)

   override lazy val metadata: Map[String, String] =
     if (wrapped == null) Map.empty else wrapped.metadata
diff --git a/spark/src/main/spark-3.5/org/apache/spark/sql/comet/shims/ShimCometScanExec.scala b/spark/src/main/spark-3.5/org/apache/spark/sql/comet/shims/ShimCometScanExec.scala
index 3c3e8c471..e4a5584aa 100644
--- a/spark/src/main/spark-3.5/org/apache/spark/sql/comet/shims/ShimCometScanExec.scala
+++ b/spark/src/main/spark-3.5/org/apache/spark/sql/comet/shims/ShimCometScanExec.scala
@@ -24,11 +24,12 @@ import org.apache.hadoop.fs.Path

 import org.apache.spark.sql.SparkSession
 import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.expressions.AttributeReference
+import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Expression}
 import org.apache.spark.sql.errors.QueryExecutionErrors
 import org.apache.spark.sql.execution.{FileSourceScanExec, PartitionedFileUtil}
 import org.apache.spark.sql.execution.datasources._
 import org.apache.spark.sql.execution.datasources.parquet.ParquetOptions
+import org.apache.spark.sql.sources.Filter
 import org.apache.spark.sql.types.StructType

 trait ShimCometScanExec {
@@ -67,4 +68,9 @@ trait ShimCometScanExec {
       maxSplitBytes: Long,
       partitionValues: InternalRow): Seq[PartitionedFile] =
     PartitionedFileUtil.splitFiles(sparkSession, file, isSplitable, maxSplitBytes, partitionValues)
+
+  protected def getPushedDownFilters(relation: HadoopFsRelation, dataFilters: Seq[Expression]): Seq[Filter] = {
+    val supportNestedPredicatePushdown = DataSourceUtils.supportNestedPredicatePushdown(relation)
+    dataFilters.flatMap(DataSourceStrategy.translateFilter(_, supportNestedPredicatePushdown))
+  }
 }
diff --git a/spark/src/main/spark-4.0/org/apache/spark/sql/comet/shims/ShimCometScanExec.scala b/spark/src/main/spark-4.0/org/apache/spark/sql/comet/shims/ShimCometScanExec.scala
index 48b9c8086..3edc43278 100644
--- a/spark/src/main/spark-4.0/org/apache/spark/sql/comet/shims/ShimCometScanExec.scala
+++ b/spark/src/main/spark-4.0/org/apache/spark/sql/comet/shims/ShimCometScanExec.scala
@@ -23,11 +23,12 @@ import org.apache.hadoop.fs.Path

 import org.apache.spark.sql.SparkSession
 import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.expressions.AttributeReference
+import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Expression, FileSourceConstantMetadataAttribute, Literal}
 import org.apache.spark.sql.errors.QueryExecutionErrors
 import org.apache.spark.sql.execution.datasources.parquet.ParquetOptions
 import org.apache.spark.sql.execution.datasources._
-import org.apache.spark.sql.execution.{FileSourceScanExec, PartitionedFileUtil}
+import org.apache.spark.sql.execution.{FileSourceScanExec, PartitionedFileUtil, ScalarSubquery}
+import org.apache.spark.sql.sources.Filter
 import org.apache.spark.sql.types.StructType

 trait ShimCometScanExec {
@@ -68,4 +69,30 @@ trait ShimCometScanExec {
       maxSplitBytes: Long,
       partitionValues: InternalRow): Seq[PartitionedFile] =
     PartitionedFileUtil.splitFiles(file, isSplitable, maxSplitBytes, partitionValues)
+
+  protected def getPushedDownFilters(relation: HadoopFsRelation, dataFilters: Seq[Expression]): Seq[Filter] = {
+    translateToV1Filters(relation, dataFilters, _.toLiteral)
+  }
+
+  // From Spark FileSourceScanLike
+  private def translateToV1Filters(relation: HadoopFsRelation,
+      dataFilters: Seq[Expression],
+      scalarSubqueryToLiteral: ScalarSubquery => Literal): Seq[Filter] = {
+    val scalarSubqueryReplaced = dataFilters.map(_.transform {
+      // Replace scalar subquery to literal so that `DataSourceStrategy.translateFilter` can
+      // support translating it.
+      case scalarSubquery: ScalarSubquery => scalarSubqueryToLiteral(scalarSubquery)
+    })
+
+    val supportNestedPredicatePushdown = DataSourceUtils.supportNestedPredicatePushdown(relation)
+    // `dataFilters` should not include any constant metadata col filters
+    // because the metadata struct has been flatted in FileSourceStrategy
+    // and thus metadata col filters are invalid to be pushed down. Metadata that is generated
+    // during the scan can be used for filters.
+    scalarSubqueryReplaced.filterNot(_.references.exists {
+      case FileSourceConstantMetadataAttribute(_) => true
+      case _ => false
+    }).flatMap(DataSourceStrategy.translateFilter(_, supportNestedPredicatePushdown))
+  }
+
 }
diff --git a/spark/src/main/spark-pre-3.5/org/apache/spark/sql/comet/shims/ShimCometScanExec.scala b/spark/src/main/spark-pre-3.5/org/apache/spark/sql/comet/shims/ShimCometScanExec.scala
index 900b19895..a25be3bc0 100644
--- a/spark/src/main/spark-pre-3.5/org/apache/spark/sql/comet/shims/ShimCometScanExec.scala
+++ b/spark/src/main/spark-pre-3.5/org/apache/spark/sql/comet/shims/ShimCometScanExec.scala
@@ -20,16 +20,15 @@ package org.apache.spark.sql.comet.shims

 import org.apache.comet.shims.ShimFileFormat
-
 import org.apache.hadoop.fs.{FileStatus, Path}
-
 import org.apache.spark.SparkException
 import org.apache.spark.sql.SparkSession
 import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.expressions.AttributeReference
+import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Expression}
+import org.apache.spark.sql.execution.datasources._
 import org.apache.spark.sql.execution.datasources.parquet.ParquetOptions
 import org.apache.spark.sql.execution.{FileSourceScanExec, PartitionedFileUtil}
-import org.apache.spark.sql.execution.datasources.{FilePartition, FileScanRDD, HadoopFsRelation, PartitionDirectory, PartitionedFile}
+import org.apache.spark.sql.sources.Filter
 import org.apache.spark.sql.types.StructType

 trait ShimCometScanExec {
@@ -102,4 +101,10 @@ trait ShimCometScanExec {
       maxSplitBytes: Long,
       partitionValues: InternalRow): Seq[PartitionedFile] =
     PartitionedFileUtil.splitFiles(sparkSession, file, filePath, isSplitable, maxSplitBytes, partitionValues)
+
+  protected def getPushedDownFilters(relation: HadoopFsRelation, dataFilters: Seq[Expression]): Seq[Filter] = {
+    val supportNestedPredicatePushdown = DataSourceUtils.supportNestedPredicatePushdown(relation)
+    dataFilters.flatMap(DataSourceStrategy.translateFilter(_, supportNestedPredicatePushdown))
+  }
+
 }