[SPARK-16596] [SQL] Refactor DataSourceScanExec to do partition discovery at execution instead of planning time #14241

Closed · wants to merge 23 commits
Changes from all commits
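In short: FileSourceStrategy stops listing files and building the scan RDD during planning; it now constructs a scan node that carries the partition filters and resolves its input only when the plan runs. A minimal conceptual sketch of that shift (class and parameter names here are illustrative stand-ins, not the PR's actual API):

import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.Expression

// Illustrative only. Before this change the planner roughly did:
//   val partitions = relation.location.listFiles(partitionFilters)   // runs at planning time
//   DataSourceScanExec.create(output, buildRDD(partitions), relation, meta, table)
// After it, the scan node keeps the filters and lists files lazily:
final class ExecutionTimeScanSketch(
    listFiles: Seq[Expression] => Seq[String],    // stand-in for relation.location.listFiles
    buildRDD: Seq[String] => RDD[InternalRow],    // stand-in for FileScanRDD construction
    partitionFilters: Seq[Expression]) {

  // Deferred: nothing is listed until the plan is actually executed.
  @transient private lazy val selectedFiles: Seq[String] = listFiles(partitionFilters)

  def execute(): RDD[InternalRow] = buildRDD(selectedFiles)
}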
@@ -300,7 +300,7 @@ abstract class QueryPlan[PlanType <: QueryPlan[PlanType]] extends TreeNode[PlanT
*/
lazy val allAttributes: AttributeSeq = children.flatMap(_.output)

-private def cleanExpression(e: Expression): Expression = e match {
+protected def cleanExpression(e: Expression): Expression = e match {
case a: Alias =>
// As the root of the expression, Alias will always take an arbitrary exprId, we need
// to erase that for equality testing.
395 changes: 318 additions & 77 deletions sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala

Large diffs are not rendered by default.

@@ -31,10 +31,10 @@ import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.planning.PhysicalOperation
import org.apache.spark.sql.catalyst.plans.logical
import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Project}
-import org.apache.spark.sql.catalyst.plans.physical.HashPartitioning
+import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, UnknownPartitioning}
import org.apache.spark.sql.catalyst.rules.Rule
-import org.apache.spark.sql.execution.DataSourceScanExec.PUSHED_FILTERS
-import org.apache.spark.sql.execution.SparkPlan
+import org.apache.spark.sql.execution.{RowDataSourceScanExec, SparkPlan}
import org.apache.spark.sql.execution.command.{CreateDataSourceTableUtils, DDLUtils, ExecutedCommandExec}
import org.apache.spark.sql.sources._
import org.apache.spark.sql.types._
@@ -268,8 +268,13 @@ private[sql] object DataSourceStrategy extends Strategy with Logging {
(a, _) => toCatalystRDD(l, a, t.buildScan(a.map(_.name).toArray))) :: Nil

case l @ LogicalRelation(baseRelation: TableScan, _, _) =>
-execution.DataSourceScanExec.create(
-l.output, toCatalystRDD(l, baseRelation.buildScan()), baseRelation) :: Nil
+RowDataSourceScanExec(
+l.output,
+toCatalystRDD(l, baseRelation.buildScan()),
+baseRelation,
+UnknownPartitioning(0),
+Map.empty,
+None) :: Nil

case i @ logical.InsertIntoTable(l @ LogicalRelation(t: InsertableRelation, _, _),
part, query, overwrite, false) if part.isEmpty =>
@@ -375,20 +380,20 @@ private[sql] object DataSourceStrategy extends Strategy with Logging {
// Don't request columns that are only referenced by pushed filters.
.filterNot(handledSet.contains)

-val scan = execution.DataSourceScanExec.create(
+val scan = RowDataSourceScanExec(
projects.map(_.toAttribute),
scanBuilder(requestedColumns, candidatePredicates, pushedFilters),
-relation.relation, metadata, relation.metastoreTableIdentifier)
+relation.relation, UnknownPartitioning(0), metadata, relation.metastoreTableIdentifier)
filterCondition.map(execution.FilterExec(_, scan)).getOrElse(scan)
} else {
// Don't request columns that are only referenced by pushed filters.
val requestedColumns =
(projectSet ++ filterSet -- handledSet).map(relation.attributeMap).toSeq

-val scan = execution.DataSourceScanExec.create(
+val scan = RowDataSourceScanExec(
requestedColumns,
scanBuilder(requestedColumns, candidatePredicates, pushedFilters),
-relation.relation, metadata, relation.metastoreTableIdentifier)
+relation.relation, UnknownPartitioning(0), metadata, relation.metastoreTableIdentifier)
execution.ProjectExec(
projects, filterCondition.map(execution.FilterExec(_, scan)).getOrElse(scan))
}
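Both call sites above pass the same six arguments, so the new RowDataSourceScanExec presumably has roughly the constructor below (a sketch inferred from this diff; the real definition is in ExistingRDD.scala, whose diff is not rendered above, and the parameter names are guesses):

import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.{InternalRow, TableIdentifier}
import org.apache.spark.sql.catalyst.expressions.Attribute
import org.apache.spark.sql.catalyst.plans.physical.Partitioning
import org.apache.spark.sql.sources.BaseRelation

// Sketch only: parameter names inferred from the arguments at the call sites above.
case class RowDataSourceScanExecSketch(
    output: Seq[Attribute],                             // columns the scan produces
    rdd: RDD[InternalRow],                              // rows from the underlying relation
    relation: BaseRelation,                             // the data source being scanned
    outputPartitioning: Partitioning,                   // UnknownPartitioning(0) at these call sites
    metadata: Map[String, String],                      // explain/equality metadata, e.g. pushed filters
    metastoreTableIdentifier: Option[TableIdentifier])  // None for a plain TableScan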
@@ -17,10 +17,6 @@

package org.apache.spark.sql.execution.datasources

-import scala.collection.mutable.ArrayBuffer
-
-import org.apache.hadoop.fs.{BlockLocation, FileStatus, LocatedFileStatus, Path}
-
import org.apache.spark.internal.Logging
import org.apache.spark.rdd.RDD
import org.apache.spark.sql._
@@ -29,8 +25,8 @@ import org.apache.spark.sql.catalyst.catalog.BucketSpec
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.planning.PhysicalOperation
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
-import org.apache.spark.sql.execution.DataSourceScanExec
-import org.apache.spark.sql.execution.DataSourceScanExec.{INPUT_PATHS, PUSHED_FILTERS}
+import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, UnknownPartitioning}
+import org.apache.spark.sql.execution.FileSourceScanExec
import org.apache.spark.sql.execution.SparkPlan

/**
@@ -96,8 +92,6 @@ private[sql] object FileSourceStrategy extends Strategy with Logging {
val afterScanFilters = filterSet -- partitionKeyFilters
logInfo(s"Post-Scan Filters: ${afterScanFilters.mkString(",")}")

-val selectedPartitions = fsRelation.location.listFiles(partitionKeyFilters.toSeq)
-
val filterAttributes = AttributeSet(afterScanFilters)
val requiredExpressions: Seq[NamedExpression] = filterAttributes.toSeq ++ projects
val requiredAttributes = AttributeSet(requiredExpressions)
@@ -106,44 +100,21 @@
dataColumns
.filter(requiredAttributes.contains)
.filterNot(partitionColumns.contains)
-val prunedDataSchema = readDataColumns.toStructType
-logInfo(s"Pruned Data Schema: ${prunedDataSchema.simpleString(5)}")
+val outputSchema = readDataColumns.toStructType
+logInfo(s"Output Data Schema: ${outputSchema.simpleString(5)}")

val pushedDownFilters = dataFilters.flatMap(DataSourceStrategy.translateFilter)
logInfo(s"Pushed Filters: ${pushedDownFilters.mkString(",")}")

-val readFile: (PartitionedFile) => Iterator[InternalRow] =
-fsRelation.fileFormat.buildReaderWithPartitionValues(
-sparkSession = fsRelation.sparkSession,
-dataSchema = fsRelation.dataSchema,
-partitionSchema = fsRelation.partitionSchema,
-requiredSchema = prunedDataSchema,
-filters = pushedDownFilters,
-options = fsRelation.options,
-hadoopConf =
-fsRelation.sparkSession.sessionState.newHadoopConfWithOptions(fsRelation.options))
-
-val rdd = fsRelation.bucketSpec match {
-case Some(bucketing) if fsRelation.sparkSession.sessionState.conf.bucketingEnabled =>
-createBucketedReadRDD(bucketing, readFile, selectedPartitions, fsRelation)
-case _ =>
-createNonBucketedReadRDD(readFile, selectedPartitions, fsRelation)
-}

-// These metadata values make scan plans uniquely identifiable for equality checking.
-val meta = Map(
-"PartitionFilters" -> partitionKeyFilters.mkString("[", ", ", "]"),
-"Format" -> fsRelation.fileFormat.toString,
-"ReadSchema" -> prunedDataSchema.simpleString,
-PUSHED_FILTERS -> pushedDownFilters.mkString("[", ", ", "]"),
-INPUT_PATHS -> fsRelation.location.paths.mkString(", "))
+val outputAttributes = readDataColumns ++ partitionColumns

val scan =
-DataSourceScanExec.create(
-readDataColumns ++ partitionColumns,
-rdd,
+new FileSourceScanExec(
fsRelation,
-meta,
+outputAttributes,
+outputSchema,
+partitionKeyFilters.toSeq,
+pushedDownFilters,
table)

val afterScanFilter = afterScanFilters.toSeq.reduceOption(expressions.And)
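The strategy now hands the relation, the pruned schema, and both filter sets to the scan node instead of a pre-built RDD. From this call site, the new FileSourceScanExec constructor appears to take roughly the shape sketched below (parameter names are inferred; the real class is added in ExistingRDD.scala). The lazy selectedPartitions member shows where the listFiles call deleted above presumably ends up, which is what defers partition discovery to execution:

import org.apache.spark.sql.catalyst.TableIdentifier
import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression}
import org.apache.spark.sql.execution.datasources.HadoopFsRelation
import org.apache.spark.sql.sources.Filter
import org.apache.spark.sql.types.StructType

// Sketch only: shape inferred from the arguments passed above.
case class FileSourceScanExecSketch(
    relation: HadoopFsRelation,                   // the file-based relation to scan
    output: Seq[Attribute],                       // data columns plus partition columns
    outputSchema: StructType,                     // pruned schema of the data columns to read
    partitionFilters: Seq[Expression],            // predicates on partition columns
    dataFilters: Seq[Filter],                     // translated filters pushed to the file format
    metastoreTableIdentifier: Option[TableIdentifier]) {

  // Partition discovery now happens here, lazily, rather than in FileSourceStrategy.
  @transient private lazy val selectedPartitions = relation.location.listFiles(partitionFilters)
}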
@@ -158,155 +129,4 @@ private[sql] object FileSourceStrategy extends Strategy with Logging {

case _ => Nil
}

/**
* Create an RDD for bucketed reads.
* The non-bucketed variant of this function is [[createNonBucketedReadRDD]].
*
* The algorithm is pretty simple: each RDD partition being returned should include all the files
* with the same bucket id from all the given Hive partitions.
*
* @param bucketSpec the bucketing spec.
* @param readFile a function to read each (part of a) file.
* @param selectedPartitions Hive-style partition that are part of the read.
* @param fsRelation [[HadoopFsRelation]] associated with the read.
*/
private def createBucketedReadRDD(
bucketSpec: BucketSpec,
readFile: (PartitionedFile) => Iterator[InternalRow],
selectedPartitions: Seq[Partition],
fsRelation: HadoopFsRelation): RDD[InternalRow] = {
logInfo(s"Planning with ${bucketSpec.numBuckets} buckets")
val bucketed =
selectedPartitions.flatMap { p =>
p.files.map { f =>
val hosts = getBlockHosts(getBlockLocations(f), 0, f.getLen)
PartitionedFile(p.values, f.getPath.toUri.toString, 0, f.getLen, hosts)
}
}.groupBy { f =>
BucketingUtils
.getBucketId(new Path(f.filePath).getName)
.getOrElse(sys.error(s"Invalid bucket file ${f.filePath}"))
}

val filePartitions = Seq.tabulate(bucketSpec.numBuckets) { bucketId =>
FilePartition(bucketId, bucketed.getOrElse(bucketId, Nil))
}

new FileScanRDD(fsRelation.sparkSession, readFile, filePartitions)
}

/**
* Create an RDD for non-bucketed reads.
* The bucketed variant of this function is [[createBucketedReadRDD]].
*
* @param readFile a function to read each (part of a) file.
* @param selectedPartitions Hive-style partition that are part of the read.
* @param fsRelation [[HadoopFsRelation]] associated with the read.
*/
private def createNonBucketedReadRDD(
readFile: (PartitionedFile) => Iterator[InternalRow],
selectedPartitions: Seq[Partition],
fsRelation: HadoopFsRelation): RDD[InternalRow] = {
val defaultMaxSplitBytes =
fsRelation.sparkSession.sessionState.conf.filesMaxPartitionBytes
val openCostInBytes = fsRelation.sparkSession.sessionState.conf.filesOpenCostInBytes
val defaultParallelism = fsRelation.sparkSession.sparkContext.defaultParallelism
val totalBytes = selectedPartitions.flatMap(_.files.map(_.getLen + openCostInBytes)).sum
val bytesPerCore = totalBytes / defaultParallelism

val maxSplitBytes = Math.min(defaultMaxSplitBytes, Math.max(openCostInBytes, bytesPerCore))
logInfo(s"Planning scan with bin packing, max size: $maxSplitBytes bytes, " +
s"open cost is considered as scanning $openCostInBytes bytes.")

val splitFiles = selectedPartitions.flatMap { partition =>
partition.files.flatMap { file =>
val blockLocations = getBlockLocations(file)
if (fsRelation.fileFormat.isSplitable(
fsRelation.sparkSession, fsRelation.options, file.getPath)) {
(0L until file.getLen by maxSplitBytes).map { offset =>
val remaining = file.getLen - offset
val size = if (remaining > maxSplitBytes) maxSplitBytes else remaining
val hosts = getBlockHosts(blockLocations, offset, size)
PartitionedFile(
partition.values, file.getPath.toUri.toString, offset, size, hosts)
}
} else {
val hosts = getBlockHosts(blockLocations, 0, file.getLen)
Seq(PartitionedFile(
partition.values, file.getPath.toUri.toString, 0, file.getLen, hosts))
}
}
}.toArray.sortBy(_.length)(implicitly[Ordering[Long]].reverse)

val partitions = new ArrayBuffer[FilePartition]
val currentFiles = new ArrayBuffer[PartitionedFile]
var currentSize = 0L

/** Close the current partition and move to the next. */
def closePartition(): Unit = {
if (currentFiles.nonEmpty) {
val newPartition =
FilePartition(
partitions.size,
currentFiles.toArray.toSeq) // Copy to a new Array.
partitions.append(newPartition)
}
currentFiles.clear()
currentSize = 0
}

// Assign files to partitions using "First Fit Decreasing" (FFD)
// TODO: consider adding a slop factor here?
splitFiles.foreach { file =>
if (currentSize + file.length > maxSplitBytes) {
closePartition()
}
// Add the given file to the current partition.
currentSize += file.length + openCostInBytes
currentFiles.append(file)
}
closePartition()

new FileScanRDD(fsRelation.sparkSession, readFile, partitions)
}

private def getBlockLocations(file: FileStatus): Array[BlockLocation] = file match {
case f: LocatedFileStatus => f.getBlockLocations
case f => Array.empty[BlockLocation]
}

// Given locations of all blocks of a single file, `blockLocations`, and an `(offset, length)`
// pair that represents a segment of the same file, find out the block that contains the largest
// fraction the segment, and returns location hosts of that block. If no such block can be found,
// returns an empty array.
private def getBlockHosts(
blockLocations: Array[BlockLocation], offset: Long, length: Long): Array[String] = {
val candidates = blockLocations.map {
// The fragment starts from a position within this block
case b if b.getOffset <= offset && offset < b.getOffset + b.getLength =>
b.getHosts -> (b.getOffset + b.getLength - offset).min(length)

// The fragment ends at a position within this block
case b if offset <= b.getOffset && offset + length < b.getLength =>
b.getHosts -> (offset + length - b.getOffset).min(length)

// The fragment fully contains this block
case b if offset <= b.getOffset && b.getOffset + b.getLength <= offset + length =>
b.getHosts -> b.getLength

// The fragment doesn't intersect with this block
case b =>
b.getHosts -> 0L
}.filter { case (hosts, size) =>
size > 0L
}

if (candidates.isEmpty) {
Array.empty[String]
} else {
val (hosts, _) = candidates.maxBy { case (_, size) => size }
hosts
}
}
}
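The helpers deleted above (createBucketedReadRDD, createNonBucketedReadRDD, getBlockLocations, getBlockHosts) do not disappear; per the PR title they presumably move into the new scan node, which would build its input RDD along these lines. A sketch under that assumption, with the node's fields and the moved helpers represented as abstract members:

import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.catalog.BucketSpec
import org.apache.spark.sql.execution.datasources.{HadoopFsRelation, Partition, PartitionedFile}
import org.apache.spark.sql.sources.Filter
import org.apache.spark.sql.types.StructType

// Sketch: the reader function and the (bucketed or non-bucketed) RDD are created only
// when the plan executes, from the relation and filters captured at planning time.
trait ExecutionTimeReadSketch {
  def relation: HadoopFsRelation
  def outputSchema: StructType
  def dataFilters: Seq[Filter]
  def selectedPartitions: Seq[Partition]

  // Stand-ins for the helpers removed from FileSourceStrategy above.
  protected def createBucketedReadRDD(
      bucketSpec: BucketSpec,
      readFile: PartitionedFile => Iterator[InternalRow],
      partitions: Seq[Partition],
      fsRelation: HadoopFsRelation): RDD[InternalRow]
  protected def createNonBucketedReadRDD(
      readFile: PartitionedFile => Iterator[InternalRow],
      partitions: Seq[Partition],
      fsRelation: HadoopFsRelation): RDD[InternalRow]

  // Only evaluated when the plan actually runs.
  @transient lazy val inputRDD: RDD[InternalRow] = {
    val readFile = relation.fileFormat.buildReaderWithPartitionValues(
      sparkSession = relation.sparkSession,
      dataSchema = relation.dataSchema,
      partitionSchema = relation.partitionSchema,
      requiredSchema = outputSchema,
      filters = dataFilters,
      options = relation.options,
      hadoopConf = relation.sparkSession.sessionState.newHadoopConfWithOptions(relation.options))

    relation.bucketSpec match {
      case Some(bucketing) if relation.sparkSession.sessionState.conf.bucketingEnabled =>
        createBucketedReadRDD(bucketing, readFile, selectedPartitions, relation)
      case _ =>
        createNonBucketedReadRDD(readFile, selectedPartitions, relation)
    }
  }
}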
@@ -24,7 +24,7 @@ import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.{BlockLocation, FileStatus, Path, RawLocalFileSystem}
import org.apache.hadoop.mapreduce.Job

-import org.apache.spark.SparkConf
+import org.apache.spark.{SparkConf, SparkException}
import org.apache.spark.sql._
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.catalog.BucketSpec
@@ -518,8 +518,8 @@ class FileSourceStrategySuite extends QueryTest with SharedSQLContext with Predi

def getFileScanRDD(df: DataFrame): FileScanRDD = {
df.queryExecution.executedPlan.collect {
-case scan: DataSourceScanExec if scan.rdd.isInstanceOf[FileScanRDD] =>
-scan.rdd.asInstanceOf[FileScanRDD]
+case scan: DataSourceScanExec if scan.inputRDDs().head.isInstanceOf[FileScanRDD] =>
+scan.inputRDDs().head.asInstanceOf[FileScanRDD]
}.headOption.getOrElse {
fail(s"No FileScan in query\n${df.queryExecution}")
}
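Since the scan node no longer exposes a pre-built rdd field, the test reaches the FileScanRDD through inputRDDs() instead; the same change repeats in the suites further down. A small helper in that style (a sketch; the helper name is ours, not part of the PR):

import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.execution.DataSourceScanExec
import org.apache.spark.sql.execution.datasources.FileScanRDD

// Sketch: pull out the FileScanRDD that a file-based scan will execute.
def fileScanRDDOf(df: DataFrame): FileScanRDD =
  df.queryExecution.executedPlan.collect {
    case scan: DataSourceScanExec if scan.inputRDDs().head.isInstanceOf[FileScanRDD] =>
      scan.inputRDDs().head.asInstanceOf[FileScanRDD]
  }.headOption.getOrElse(sys.error(s"No FileScanRDD in query\n${df.queryExecution}"))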
@@ -25,7 +25,7 @@ import org.apache.parquet.hadoop.ParquetOutputFormat
import org.apache.spark.sql._
import org.apache.spark.sql.catalyst.{InternalRow, TableIdentifier}
import org.apache.spark.sql.catalyst.expressions.SpecificMutableRow
-import org.apache.spark.sql.execution.BatchedDataSourceScanExec
+import org.apache.spark.sql.execution.FileSourceScanExec
import org.apache.spark.sql.execution.datasources.parquet.TestingUDT.{NestedStruct, NestedStructUDT, SingleElement}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.test.SharedSQLContext
@@ -624,16 +624,15 @@ class ParquetQuerySuite extends QueryTest with ParquetTest with SharedSQLContext

// donot return batch, because whole stage codegen is disabled for wide table (>200 columns)
val df2 = spark.read.parquet(path)
-assert(df2.queryExecution.sparkPlan.find(_.isInstanceOf[BatchedDataSourceScanExec]).isEmpty,
-"Should not return batch")
+val fileScan2 = df2.queryExecution.sparkPlan.find(_.isInstanceOf[FileSourceScanExec]).get
+assert(!fileScan2.asInstanceOf[FileSourceScanExec].supportsBatch)
checkAnswer(df2, df)

// return batch
val columns = Seq.tabulate(9) {i => s"c$i"}
val df3 = df2.selectExpr(columns : _*)
-assert(
-df3.queryExecution.sparkPlan.find(_.isInstanceOf[BatchedDataSourceScanExec]).isDefined,
-"Should return batch")
+val fileScan3 = df3.queryExecution.sparkPlan.find(_.isInstanceOf[FileSourceScanExec]).get
+assert(fileScan3.asInstanceOf[FileSourceScanExec].supportsBatch)
checkAnswer(df3, df.selectExpr(columns : _*))
}
}
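The test changes also show that batched (vectorized) reads are no longer a separate plan node to look for: the single FileSourceScanExec now reports this via supportsBatch. A test can therefore ask the scan node directly (a sketch; the helper name is ours):

import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.execution.FileSourceScanExec

// Sketch: does the file scan in this plan use the vectorized (batched) path?
def usesBatchedScan(df: DataFrame): Boolean =
  df.queryExecution.sparkPlan.collectFirst {
    case scan: FileSourceScanExec => scan.supportsBatch
  }.getOrElse(false)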
@@ -198,8 +198,8 @@ class FileStreamSinkSuite extends StreamTest {
/** Check some condition on the partitions of the FileScanRDD generated by a DF */
def checkFileScanPartitions(df: DataFrame)(func: Seq[FilePartition] => Unit): Unit = {
val getFileScanRDD = df.queryExecution.executedPlan.collect {
-case scan: DataSourceScanExec if scan.rdd.isInstanceOf[FileScanRDD] =>
-scan.rdd.asInstanceOf[FileScanRDD]
+case scan: DataSourceScanExec if scan.inputRDDs().head.isInstanceOf[FileScanRDD] =>
+scan.inputRDDs().head.asInstanceOf[FileScanRDD]
}.headOption.getOrElse {
fail(s"No FileScan in query\n${df.queryExecution}")
}
@@ -358,11 +358,11 @@ class BucketedReadSuite extends QueryTest with SQLTestUtils with TestHiveSinglet
df1.write.parquet(tableDir.getAbsolutePath)

val agged = spark.table("bucketed_table").groupBy("i").count()
-val error = intercept[RuntimeException] {
+val error = intercept[Exception] {
Contributor:
NIT: can't we catch the proper exception?

Contributor (author):
It's a nested exception, which is quite hard to match. The following assert checks for the right error message, which is the important bit, I think.

agged.count()
}

-assert(error.toString contains "Invalid bucket file")
+assert(error.getCause().toString contains "Invalid bucket file")
}
}
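The point in the thread above, that the wrapper exception type is hard to pin down so the test matches on the cause's message, generalizes to walking the whole cause chain. A sketch that could slot into the test body shown here (the causeChain helper is hypothetical, not part of this PR; agged and intercept come from the surrounding test):

// Sketch: assert on a message anywhere in the cause chain, independent of the wrapper type.
def causeChain(t: Throwable): Seq[Throwable] =
  Iterator.iterate(t)(_.getCause).takeWhile(_ != null).toSeq

val error = intercept[Exception] { agged.count() }
assert(causeChain(error).exists(_.toString.contains("Invalid bucket file")))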

Expand Up @@ -862,8 +862,8 @@ abstract class HadoopFsRelationTest extends QueryTest with SQLTestUtils with Tes
.load(path)

val Some(fileScanRDD) = df2.queryExecution.executedPlan.collectFirst {
-case scan: DataSourceScanExec if scan.rdd.isInstanceOf[FileScanRDD] =>
-scan.rdd.asInstanceOf[FileScanRDD]
+case scan: DataSourceScanExec if scan.inputRDDs().head.isInstanceOf[FileScanRDD] =>
+scan.inputRDDs().head.asInstanceOf[FileScanRDD]
}

val partitions = fileScanRDD.partitions