[SPARK-30428][SQL] File source V2: support partition pruning #27112

AvroScan.scala
@@ -21,6 +21,7 @@ import scala.collection.JavaConverters._
import org.apache.hadoop.fs.Path

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.expressions.Expression
import org.apache.spark.sql.connector.read.PartitionReaderFactory
import org.apache.spark.sql.execution.datasources.PartitioningAwareFileIndex
import org.apache.spark.sql.execution.datasources.v2.FileScan
@@ -34,19 +35,30 @@ case class AvroScan(
dataSchema: StructType,
readDataSchema: StructType,
readPartitionSchema: StructType,
options: CaseInsensitiveStringMap)
extends FileScan(sparkSession, fileIndex, readDataSchema, readPartitionSchema) {
override def isSplitable(path: Path): Boolean = true

override def createReaderFactory(): PartitionReaderFactory = {
val caseSensitiveMap = options.asCaseSensitiveMap.asScala.toMap
// Hadoop Configurations are case sensitive.
val hadoopConf = sparkSession.sessionState.newHadoopConfWithOptions(caseSensitiveMap)
val broadcastedConf = sparkSession.sparkContext.broadcast(
new SerializableConfiguration(hadoopConf))
// The partition values are already truncated in `FileScan.partitions`.
// We should use `readPartitionSchema` as the partition schema here.
AvroPartitionReaderFactory(sparkSession.sessionState.conf, broadcastedConf,
dataSchema, readDataSchema, readPartitionSchema, caseSensitiveMap)
}
options: CaseInsensitiveStringMap,
partitionFilters: Seq[Expression] = Seq.empty) extends FileScan {
override def isSplitable(path: Path): Boolean = true
Review comment (Member, PR author): I found the indentation is wrong in AvroScan; it is fixed here as well.

override def createReaderFactory(): PartitionReaderFactory = {
val caseSensitiveMap = options.asCaseSensitiveMap.asScala.toMap
// Hadoop Configurations are case sensitive.
val hadoopConf = sparkSession.sessionState.newHadoopConfWithOptions(caseSensitiveMap)
val broadcastedConf = sparkSession.sparkContext.broadcast(
new SerializableConfiguration(hadoopConf))
// The partition values are already truncated in `FileScan.partitions`.
// We should use `readPartitionSchema` as the partition schema here.
AvroPartitionReaderFactory(sparkSession.sessionState.conf, broadcastedConf,
dataSchema, readDataSchema, readPartitionSchema, caseSensitiveMap)
}

override def withPartitionFilters(partitionFilters: Seq[Expression]): FileScan =
this.copy(partitionFilters = partitionFilters)

override def equals(obj: Any): Boolean = obj match {
case a: AvroScan => super.equals(a) && dataSchema == a.dataSchema && options == a.options

case _ => false
}

override def hashCode(): Int = super.hashCode()
}
AvroSuite.scala
@@ -36,10 +36,15 @@ import org.apache.commons.io.FileUtils
import org.apache.spark.{SparkConf, SparkException}
import org.apache.spark.sql._
import org.apache.spark.sql.TestingUDT.{IntervalData, NullData, NullUDT}
import org.apache.spark.sql.execution.datasources.DataSource
import org.apache.spark.sql.catalyst.expressions.AttributeReference
import org.apache.spark.sql.catalyst.plans.logical.Filter
import org.apache.spark.sql.execution.SparkPlan
import org.apache.spark.sql.execution.datasources.{DataSource, FilePartition}
import org.apache.spark.sql.execution.datasources.v2.BatchScanExec
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.test.SharedSparkSession
import org.apache.spark.sql.types._
import org.apache.spark.sql.v2.avro.AvroScan
import org.apache.spark.util.Utils

abstract class AvroSuite extends QueryTest with SharedSparkSession {
@@ -1502,8 +1507,75 @@ class AvroV1Suite extends AvroSuite {
}

class AvroV2Suite extends AvroSuite {
import testImplicits._

override protected def sparkConf: SparkConf =
super
.sparkConf
.set(SQLConf.USE_V1_SOURCE_LIST, "")

test("Avro source v2: support partition pruning") {
Review comment (Contributor): Not related to this PR, but we should think about how to share test cases between the Avro suite and FileBasedDataSourceSuite. (One possible shared-trait sketch appears after this file's diff.)

withTempPath { dir =>
Seq(("a", 1, 2), ("b", 1, 2), ("c", 2, 1))
.toDF("value", "p1", "p2")
.write
.format("avro")
.partitionBy("p1", "p2")
.option("header", true)
.save(dir.getCanonicalPath)
val df = spark
.read
.format("avro")
.option("header", true)
.load(dir.getCanonicalPath)
.where("p1 = 1 and p2 = 2 and value != \"a\"")

val filterCondition = df.queryExecution.optimizedPlan.collectFirst {
case f: Filter => f.condition
}
assert(filterCondition.isDefined)
// The partition filters should be pushed down and don't need to be reevaluated.
assert(filterCondition.get.collectFirst {
case a: AttributeReference if a.name == "p1" || a.name == "p2" => a
}.isEmpty)

val fileScan = df.queryExecution.executedPlan collectFirst {
case BatchScanExec(_, f: AvroScan) => f
}
assert(fileScan.nonEmpty)
assert(fileScan.get.partitionFilters.nonEmpty)
assert(fileScan.get.planInputPartitions().forall { partition =>
partition.asInstanceOf[FilePartition].files.forall { file =>
file.filePath.contains("p1=1") && file.filePath.contains("p2=2")
}
})
checkAnswer(df, Row("b", 1, 2))
}
}

private def getBatchScanExec(plan: SparkPlan): BatchScanExec = {
plan.find(_.isInstanceOf[BatchScanExec]).get.asInstanceOf[BatchScanExec]
}

test("Avro source v2: same result with different orders of data filters and partition filters") {
withTempPath { path =>
val tmpDir = path.getCanonicalPath
spark
.range(10)
.selectExpr("id as a", "id + 1 as b", "id + 2 as c", "id + 3 as d")
.write
.partitionBy("a", "b")
.format("avro")
.save(tmpDir)
val df = spark.read.format("avro").load(tmpDir)
// partition filters: a > 1 AND b < 9
// data filters: c > 1 AND d < 9
val plan1 = df.where("a > 1 AND b < 9 AND c > 1 AND d < 9").queryExecution.sparkPlan
val plan2 = df.where("b < 9 AND a > 1 AND d < 9 AND c > 1").queryExecution.sparkPlan
assert(plan1.sameResult(plan2))
val scan1 = getBatchScanExec(plan1)
val scan2 = getBatchScanExec(plan2)
assert(scan1.sameResult(scan2))
}
}
}
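As a rough illustration of the review note above about sharing test cases between the Avro suite and FileBasedDataSourceSuite, one possible shape is a shared trait that each format-specific suite mixes in. This is only a sketch: the trait name CommonFileSourcePruningTests and the format hook are hypothetical and are not part of this PR or of Spark.

import org.apache.spark.sql.{QueryTest, Row}
import org.apache.spark.sql.test.SharedSparkSession

// Hypothetical shared trait: each format-specific suite mixes it in and
// only supplies the short format name used by DataFrameReader/Writer.
trait CommonFileSourcePruningTests extends QueryTest with SharedSparkSession {
  import testImplicits._

  // The format under test, e.g. "avro", "parquet", "orc", "csv", "json".
  protected def format: String

  test(s"$format: partition pruning keeps only matching partitions") {
    withTempPath { dir =>
      Seq(("a", 1, 2), ("b", 1, 2), ("c", 2, 1))
        .toDF("value", "p1", "p2")
        .write
        .format(format)
        .partitionBy("p1", "p2")
        .save(dir.getCanonicalPath)
      val df = spark.read.format(format).load(dir.getCanonicalPath)
        .where("p1 = 1 and p2 = 2 and value != 'a'")
      checkAnswer(df, Row("b", 1, 2))
    }
  }
}

// A concrete suite would then reduce to something like:
// class AvroPruningSuite extends CommonFileSourcePruningTests { override protected def format: String = "avro" }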
SparkOptimizer.scala
@@ -37,7 +37,7 @@ class SparkOptimizer(

override def earlyScanPushDownRules: Seq[Rule[LogicalPlan]] =
// TODO: move SchemaPruning into catalyst
SchemaPruning :: PruneFileSourcePartitions :: V2ScanRelationPushDown :: Nil
SchemaPruning :: V2ScanRelationPushDown :: PruneFileSourcePartitions :: Nil

override def defaultBatches: Seq[Batch] = (preOptimizationBatches ++ super.defaultBatches :+
Batch("Optimize Metadata Only Query", Once, OptimizeMetadataOnlyQuery(catalog)) :+
PruneFileSourcePartitions.scala
@@ -17,13 +17,46 @@

package org.apache.spark.sql.execution.datasources

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.catalog.CatalogStatistics
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.planning.PhysicalOperation
import org.apache.spark.sql.catalyst.plans.logical.{Filter, LogicalPlan, Project}
import org.apache.spark.sql.catalyst.plans.logical.{Filter, LeafNode, LogicalPlan, Project}
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.execution.datasources.v2.{DataSourceV2ScanRelation, FileScan, FileTable}
import org.apache.spark.sql.types.StructType

private[sql] object PruneFileSourcePartitions extends Rule[LogicalPlan] {

private def getPartitionKeyFilters(
sparkSession: SparkSession,
relation: LeafNode,
partitionSchema: StructType,
filters: Seq[Expression],
output: Seq[AttributeReference]): ExpressionSet = {
val normalizedFilters = DataSourceStrategy.normalizeExprs(
filters.filter(f => f.deterministic && !SubqueryExpression.hasSubquery(f)), output)
val partitionColumns =
relation.resolve(partitionSchema, sparkSession.sessionState.analyzer.resolver)
val partitionSet = AttributeSet(partitionColumns)
ExpressionSet(normalizedFilters.filter { f =>
f.references.subsetOf(partitionSet)
})
}

private def rebuildPhysicalOperation(
projects: Seq[NamedExpression],
filters: Seq[Expression],
relation: LeafNode): Project = {
val withFilter = if (filters.nonEmpty) {
val filterExpression = filters.reduceLeft(And)
Filter(filterExpression, relation)
} else {
relation
}
Project(projects, withFilter)
}

override def apply(plan: LogicalPlan): LogicalPlan = plan transformDown {
case op @ PhysicalOperation(projects, filters,
logicalRelation @
@@ -39,31 +72,35 @@ private[sql] object PruneFileSourcePartitions extends Rule[LogicalPlan] {
_,
_))
if filters.nonEmpty && fsRelation.partitionSchemaOption.isDefined =>
val normalizedFilters = DataSourceStrategy.normalizeExprs(
filters.filterNot(SubqueryExpression.hasSubquery), logicalRelation.output)

val sparkSession = fsRelation.sparkSession
val partitionColumns =
logicalRelation.resolve(
partitionSchema, sparkSession.sessionState.analyzer.resolver)
val partitionSet = AttributeSet(partitionColumns)
val partitionKeyFilters = ExpressionSet(normalizedFilters.filter { f =>
f.references.subsetOf(partitionSet)
})

val partitionKeyFilters = getPartitionKeyFilters(
fsRelation.sparkSession, logicalRelation, partitionSchema, filters, logicalRelation.output)
if (partitionKeyFilters.nonEmpty) {
val prunedFileIndex = catalogFileIndex.filterPartitions(partitionKeyFilters.toSeq)
val prunedFsRelation =
fsRelation.copy(location = prunedFileIndex)(sparkSession)
fsRelation.copy(location = prunedFileIndex)(fsRelation.sparkSession)
Review comment (guykhazma, Contributor, Jan 8, 2020): I suggest passing the dataFilters as well. This is useful for FileIndex implementations that use the dataFilters to do the file listing; for example, we use this to provide data skipping for all file-based data sources. I suggest something like guykhazma@de3415b.

Review comment (Member, PR author): @guykhazma Thanks for the suggestion. However, PartitioningAwareFileIndex doesn't use the data filters for listing files. Could you provide an example where the data filters would be useful here? Also, the data filters are supposed to be pushed down in FileScanBuilder (e.g. ORC/Parquet).

Review comment (guykhazma, Contributor, Jan 9, 2020): @gengliangwang this is useful for enabling data skipping on all file formats, including formats that don't support pushdown (e.g. CSV, JSON), by replacing the FileIndex implementation with one that also uses the dataFilters to filter the file listing.

Review comment (Contributor): This is the old v1 code path, let's not touch it in this PR.

Review comment (Contributor): And for the v2 code path, the data filters are already pushed in the rule V2ScanRelationPushDown. (A sketch of such a pushdown-capable scan builder appears after this file's diff.)

// Change table stats based on the sizeInBytes of pruned files
val withStats = logicalRelation.catalogTable.map(_.copy(
stats = Some(CatalogStatistics(sizeInBytes = BigInt(prunedFileIndex.sizeInBytes)))))
val prunedLogicalRelation = logicalRelation.copy(
relation = prunedFsRelation, catalogTable = withStats)
// Keep partition-pruning predicates so that they are visible in physical planning
val filterExpression = filters.reduceLeft(And)
val filter = Filter(filterExpression, prunedLogicalRelation)
Project(projects, filter)
rebuildPhysicalOperation(projects, filters, prunedLogicalRelation)
} else {
op
}

case op @ PhysicalOperation(projects, filters,
Review comment (Member): The CSV datasource in #26973 doesn't fall into this case, but Parquet/ORC do, and withPartitionFilters is not invoked for CSV. What's wrong with CSV when filter pushdown is enabled?

v2Relation @ DataSourceV2ScanRelation(_, scan: FileScan, output))
if filters.nonEmpty && scan.readDataSchema.nonEmpty =>
val partitionKeyFilters = getPartitionKeyFilters(scan.sparkSession,
v2Relation, scan.readPartitionSchema, filters, output)
if (partitionKeyFilters.nonEmpty) {
val prunedV2Relation =
v2Relation.copy(scan = scan.withPartitionFilters(partitionKeyFilters.toSeq))
// The pushed down partition filters don't need to be reevaluated.
val afterScanFilters =
ExpressionSet(filters) -- partitionKeyFilters.filter(_.references.nonEmpty)
rebuildPhysicalOperation(projects, afterScanFilters.toSeq, prunedV2Relation)
} else {
op
}
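For context on the comment above that the data filters are already pushed in V2ScanRelationPushDown: that rule hands data filters only to scan builders implementing SupportsPushDownFilters, as the ORC/Parquet builders do. The sketch below shows the general shape of such a builder; ExampleScanBuilder and its buildScan parameter are hypothetical names used only for illustration, not Spark APIs.

import org.apache.spark.sql.connector.read.{Scan, SupportsPushDownFilters}
import org.apache.spark.sql.sources.{Filter, GreaterThan, IsNotNull, LessThan}

// Hypothetical scan builder for a format that can evaluate a few simple
// predicates natively (in the spirit of the ORC/Parquet builders).
// SupportsPushDownFilters extends ScanBuilder, so build() comes from there.
class ExampleScanBuilder(buildScan: Array[Filter] => Scan)
  extends SupportsPushDownFilters {

  private var pushed: Array[Filter] = Array.empty[Filter]

  override def pushFilters(filters: Array[Filter]): Array[Filter] = {
    // Keep the filters this format can evaluate itself; hand the rest back
    // to Spark so they are still evaluated after the scan.
    val (supported, unsupported) = filters.partition {
      case _: GreaterThan | _: LessThan | _: IsNotNull => true
      case _ => false
    }
    pushed = supported
    unsupported
  }

  override def pushedFilters(): Array[Filter] = pushed

  override def build(): Scan = buildScan(pushed)
}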
FileScan.scala
@@ -24,6 +24,7 @@ import org.apache.hadoop.fs.Path
import org.apache.spark.internal.Logging
import org.apache.spark.internal.config.IO_WARNING_LARGEFILETHRESHOLD
import org.apache.spark.sql.{AnalysisException, SparkSession}
import org.apache.spark.sql.catalyst.expressions.{Expression, ExpressionSet}
import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection
import org.apache.spark.sql.connector.read.{Batch, InputPartition, Scan, Statistics, SupportsReportStatistics}
import org.apache.spark.sql.execution.PartitionedFileUtil
@@ -32,20 +33,38 @@ import org.apache.spark.sql.sources.Filter
import org.apache.spark.sql.types.StructType
import org.apache.spark.util.Utils

abstract class FileScan(
sparkSession: SparkSession,
fileIndex: PartitioningAwareFileIndex,
readDataSchema: StructType,
readPartitionSchema: StructType)
extends Scan
with Batch with SupportsReportStatistics with Logging {
trait FileScan extends Scan with Batch with SupportsReportStatistics with Logging {
/**
* Returns whether a file with `path` could be split or not.
*/
def isSplitable(path: Path): Boolean = {
false
}

def sparkSession: SparkSession

def fileIndex: PartitioningAwareFileIndex

/**
* Returns the required data schema
*/
def readDataSchema: StructType

/**
* Returns the required partition schema
*/
def readPartitionSchema: StructType

/**
* Returns the filters that can be used for partition pruning
*/
def partitionFilters: Seq[Expression]

/**
* Create a new `FileScan` instance from the current one with different `partitionFilters`.
*/
def withPartitionFilters(partitionFilters: Seq[Expression]): FileScan

/**
* If a file with `path` is unsplittable, return the unsplittable reason,
* otherwise return `None`.
@@ -55,11 +74,24 @@
"undefined"
}

protected def seqToString(seq: Seq[Any]): String = seq.mkString("[", ", ", "]")

override def equals(obj: Any): Boolean = obj match {
case f: FileScan =>
fileIndex == f.fileIndex && readSchema == f.readSchema &&
ExpressionSet(partitionFilters) == ExpressionSet(f.partitionFilters)

case _ => false
}

override def hashCode(): Int = getClass.hashCode()

override def description(): String = {
val locationDesc =
fileIndex.getClass.getSimpleName + fileIndex.rootPaths.mkString("[", ", ", "]")
val metadata: Map[String, String] = Map(
"ReadSchema" -> readDataSchema.catalogString,
"PartitionFilters" -> seqToString(partitionFilters),
"Location" -> locationDesc)
val metadataStr = metadata.toSeq.sorted.map {
case (key, value) =>
@@ -71,7 +103,7 @@
}

protected def partitions: Seq[FilePartition] = {
val selectedPartitions = fileIndex.listFiles(Seq.empty, Seq.empty)
val selectedPartitions = fileIndex.listFiles(partitionFilters, Seq.empty)
Review comment (guykhazma, Contributor, Jan 9, 2020): @gengliangwang @cloud-fan continuing the discussion from above (the comment was on the wrong line). The V2ScanRelationPushDown rule will push down the dataFilters only to data sources that support pushdown by implementing the SupportsPushDownFilters trait. Data sources such as CSV and JSON do not implement SupportsPushDownFilters. In order to support data skipping uniformly for all file-based data sources, we override the listFiles method in a FileIndex implementation, which consults external metadata and prunes the list of files. The suggestion is to make the necessary changes so that the dataFilters are passed to listFiles as well. Otherwise, one would have to create a new data source implementation for each file-based data source that doesn't have a built-in pushdown mechanism. (See the FileIndex sketch after this file's diff.)

Review comment (Contributor): This makes sense to me. @gengliangwang what do you think? At least you can disable the v2 file source to bring back this feature.

Review comment (Member, PR author): Yes, it makes sense if there is a FileIndex that can use the dataFilters. @guykhazma could you create a PR for this?

Review comment (guykhazma, Contributor): @gengliangwang @cloud-fan sure, thanks. I have opened this PR.

val maxSplitBytes = FilePartition.maxSplitBytes(sparkSession, selectedPartitions)
val partitionAttributes = fileIndex.partitionSchema.toAttributes
val attributeMap = partitionAttributes.map(a => normalizeName(a.name) -> a).toMap
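The data-skipping idea discussed in the thread above can be pictured as a wrapper around an existing FileIndex. This is a sketch only, assuming the Spark 3.0-era FileIndex/PartitionDirectory APIs: DataSkippingFileIndex and mightMatch are hypothetical names, the min/max-statistics idea is an assumption about how such metadata might be consulted, and the wrapper only has an effect once dataFilters are actually threaded through to listFiles (the follow-up change the reviewers discuss), since this PR still passes Seq.empty for them.

import org.apache.hadoop.fs.{FileStatus, Path}
import org.apache.spark.sql.catalyst.expressions.Expression
import org.apache.spark.sql.execution.datasources.{FileIndex, PartitionDirectory}
import org.apache.spark.sql.types.StructType

// Hypothetical wrapper: `mightMatch` would consult external per-file
// metadata (e.g. min/max statistics) to decide whether a file can
// possibly satisfy the given data filters.
class DataSkippingFileIndex(
    delegate: FileIndex,
    mightMatch: (FileStatus, Seq[Expression]) => Boolean) extends FileIndex {

  override def listFiles(
      partitionFilters: Seq[Expression],
      dataFilters: Seq[Expression]): Seq[PartitionDirectory] = {
    delegate.listFiles(partitionFilters, dataFilters).map { dir =>
      // Drop files whose metadata proves they cannot match the data filters.
      dir.copy(files = dir.files.filter(f => mightMatch(f, dataFilters)))
    }
  }

  override def rootPaths: Seq[Path] = delegate.rootPaths
  override def inputFiles: Array[String] = delegate.inputFiles
  override def refresh(): Unit = delegate.refresh()
  override def sizeInBytes: Long = delegate.sizeInBytes
  override def partitionSchema: StructType = delegate.partitionSchema
}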
TextBasedFileScan.scala
@@ -29,11 +29,7 @@ import org.apache.spark.util.Utils

abstract class TextBasedFileScan(
sparkSession: SparkSession,
fileIndex: PartitioningAwareFileIndex,
readDataSchema: StructType,
readPartitionSchema: StructType,
options: CaseInsensitiveStringMap)
extends FileScan(sparkSession, fileIndex, readDataSchema, readPartitionSchema) {
options: CaseInsensitiveStringMap) extends FileScan {
@transient private lazy val codecFactory: CompressionCodecFactory = new CompressionCodecFactory(
sparkSession.sessionState.newHadoopConfWithOptions(options.asScala.toMap))

CSVScan.scala
@@ -22,11 +22,11 @@ import org.apache.hadoop.fs.Path

import org.apache.spark.sql.{AnalysisException, SparkSession}
import org.apache.spark.sql.catalyst.csv.CSVOptions
import org.apache.spark.sql.catalyst.expressions.ExprUtils
import org.apache.spark.sql.catalyst.expressions.{Expression, ExprUtils}
import org.apache.spark.sql.connector.read.PartitionReaderFactory
import org.apache.spark.sql.execution.datasources.PartitioningAwareFileIndex
import org.apache.spark.sql.execution.datasources.csv.CSVDataSource
import org.apache.spark.sql.execution.datasources.v2.TextBasedFileScan
import org.apache.spark.sql.execution.datasources.v2.{FileScan, TextBasedFileScan}
import org.apache.spark.sql.types.StructType
import org.apache.spark.sql.util.CaseInsensitiveStringMap
import org.apache.spark.util.SerializableConfiguration
@@ -37,8 +37,9 @@ case class CSVScan(
dataSchema: StructType,
readDataSchema: StructType,
readPartitionSchema: StructType,
options: CaseInsensitiveStringMap)
extends TextBasedFileScan(sparkSession, fileIndex, readDataSchema, readPartitionSchema, options) {
options: CaseInsensitiveStringMap,
partitionFilters: Seq[Expression] = Seq.empty)
extends TextBasedFileScan(sparkSession, options) {

private lazy val parsedOptions: CSVOptions = new CSVOptions(
options.asScala.toMap,
@@ -87,4 +88,15 @@ CSVPartitionReaderFactory(sparkSession.sessionState.conf, broadcastedConf,
CSVPartitionReaderFactory(sparkSession.sessionState.conf, broadcastedConf,
dataSchema, readDataSchema, readPartitionSchema, parsedOptions)
}

override def withPartitionFilters(partitionFilters: Seq[Expression]): FileScan =
this.copy(partitionFilters = partitionFilters)

override def equals(obj: Any): Boolean = obj match {
case c: CSVScan => super.equals(c) && dataSchema == c.dataSchema && options == c.options

case _ => false
}

override def hashCode(): Int = super.hashCode()
}