
Commit 02e9f93

gengliangwang authored and cloud-fan committed
[SPARK-27384][SQL] File source V2: Prune unnecessary partition columns
## What changes were proposed in this pull request?

When scanning file sources, we can prune unnecessary partition columns while constructing the input partitions, so that:
1. Less data is transferred from the driver to the executors.
2. Columnar batch readers become easier to implement, since the partition columns are already pruned.

## How was this patch tested?

Existing unit tests.

Closes #24296 from gengliangwang/prunePartitionValue.

Authored-by: Gengliang Wang <gengliang.wang@databricks.com>
Signed-off-by: Wenchen Fan <wenchen@databricks.com>
1 parent 18b36ee commit 02e9f93
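
As a rough illustration of the effect (hypothetical path and column names, assuming a running `spark` session): when a query reads only a subset of a table's partition columns, the input partitions built on the driver now carry only the pruned partition values.

```scala
// Hypothetical example: a table partitioned by `year` and `month`.
// Only `id` and `month` are selected, so after this change each planned input
// partition should carry just the `month` partition value instead of the full
// (year, month) partition row.
val df = spark.read
  .format("orc")
  .load("/path/to/events")      // directory layout: .../year=2019/month=4/...
  .select("id", "month")
  .filter("month = 4")
df.explain()
```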

File tree

10 files changed (+114 lines, -62 lines)


sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FilePartitionReaderFactory.scala

Lines changed: 0 additions & 13 deletions
@@ -46,19 +46,6 @@ abstract class FilePartitionReaderFactory extends PartitionReaderFactory {
   def buildColumnarReader(partitionedFile: PartitionedFile): PartitionReader[ColumnarBatch] = {
     throw new UnsupportedOperationException("Cannot create columnar reader.")
   }
-
-  protected def getReadDataSchema(
-      readSchema: StructType,
-      partitionSchema: StructType,
-      isCaseSensitive: Boolean): StructType = {
-    val partitionNameSet =
-      partitionSchema.fields.map(PartitioningUtils.getColName(_, isCaseSensitive)).toSet
-    val fields = readSchema.fields.filterNot { field =>
-      partitionNameSet.contains(PartitioningUtils.getColName(field, isCaseSensitive))
-    }
-
-    StructType(fields)
-  }
 }
 
 // A compound class for combining file and its corresponding reader.

sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileScan.scala

Lines changed: 37 additions & 4 deletions
@@ -16,9 +16,13 @@
  */
 package org.apache.spark.sql.execution.datasources.v2
 
+import java.util.Locale
+
 import org.apache.hadoop.fs.Path
 
-import org.apache.spark.sql.SparkSession
+import org.apache.spark.sql.{AnalysisException, SparkSession}
+import org.apache.spark.sql.catalyst.expressions.AttributeReference
+import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection
 import org.apache.spark.sql.execution.PartitionedFileUtil
 import org.apache.spark.sql.execution.datasources._
 import org.apache.spark.sql.sources.v2.reader.{Batch, InputPartition, Scan}
@@ -28,8 +32,8 @@ import org.apache.spark.sql.util.CaseInsensitiveStringMap
 abstract class FileScan(
     sparkSession: SparkSession,
     fileIndex: PartitioningAwareFileIndex,
-    readSchema: StructType,
-    options: CaseInsensitiveStringMap) extends Scan with Batch {
+    readDataSchema: StructType,
+    readPartitionSchema: StructType) extends Scan with Batch {
   /**
    * Returns whether a file with `path` could be split or not.
    */
@@ -40,7 +44,23 @@ abstract class FileScan(
   protected def partitions: Seq[FilePartition] = {
     val selectedPartitions = fileIndex.listFiles(Seq.empty, Seq.empty)
     val maxSplitBytes = FilePartition.maxSplitBytes(sparkSession, selectedPartitions)
+    val partitionAttributes = fileIndex.partitionSchema.toAttributes
+    val attributeMap = partitionAttributes.map(a => normalizeName(a.name) -> a).toMap
+    val readPartitionAttributes = readPartitionSchema.map { readField =>
+      attributeMap.get(normalizeName(readField.name)).getOrElse {
+        throw new AnalysisException(s"Can't find required partition column ${readField.name} " +
+          s"in partition schema ${fileIndex.partitionSchema}")
+      }
+    }
+    lazy val partitionValueProject =
+      GenerateUnsafeProjection.generate(readPartitionAttributes, partitionAttributes)
     val splitFiles = selectedPartitions.flatMap { partition =>
+      // Prune partition values if part of the partition columns are not required.
+      val partitionValues = if (readPartitionAttributes != partitionAttributes) {
+        partitionValueProject(partition.values).copy()
+      } else {
+        partition.values
+      }
       partition.files.flatMap { file =>
         val filePath = file.getPath
         PartitionedFileUtil.splitFiles(
@@ -49,7 +69,7 @@ abstract class FileScan(
           filePath = filePath,
           isSplitable = isSplitable(filePath),
           maxSplitBytes = maxSplitBytes,
-          partitionValues = partition.values
+          partitionValues = partitionValues
         )
       }.toArray.sortBy(_.length)(implicitly[Ordering[Long]].reverse)
     }
@@ -61,4 +81,17 @@ abstract class FileScan(
   }
 
   override def toBatch: Batch = this
+
+  override def readSchema(): StructType =
+    StructType(readDataSchema.fields ++ readPartitionSchema.fields)
+
+  private val isCaseSensitive = sparkSession.sessionState.conf.caseSensitiveAnalysis
+
+  private def normalizeName(name: String): String = {
+    if (isCaseSensitive) {
+      name
+    } else {
+      name.toLowerCase(Locale.ROOT)
+    }
+  }
 }
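
The pruning above hinges on projecting the full partition-values row down to the required partition attributes. A minimal, self-contained sketch of the same idea in plain Scala (simplified types; not the actual `GenerateUnsafeProjection`-based code):

```scala
// Simplified model of the pruning done in FileScan.partitions: a partition-values
// row is reordered/truncated to only the required partition columns before it is
// attached to each PartitionedFile.
case class SimplePartition(columnNames: Seq[String], values: Seq[Any])

def pruneValues(partition: SimplePartition, requiredColumns: Seq[String]): Seq[Any] = {
  // Equivalent in spirit to partitionValueProject(partition.values):
  // pick, in order, only the values of the columns the scan actually needs.
  val index = partition.columnNames.zipWithIndex.toMap
  requiredColumns.map { name =>
    partition.values(index.getOrElse(name,
      throw new IllegalArgumentException(s"Can't find required partition column $name")))
  }
}

// For a table partitioned by (year, month) where only `month` is read:
val pruned = pruneValues(SimplePartition(Seq("year", "month"), Seq(2019, 4)), Seq("month"))
// pruned == Seq(4)
```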

sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileScanBuilder.scala

Lines changed: 35 additions & 6 deletions
@@ -16,15 +16,44 @@
  */
 package org.apache.spark.sql.execution.datasources.v2
 
-import org.apache.spark.sql.sources.v2.reader.{ScanBuilder, SupportsPushDownFilters, SupportsPushDownRequiredColumns}
+import org.apache.spark.sql.SparkSession
+import org.apache.spark.sql.execution.datasources.{PartitioningAwareFileIndex, PartitioningUtils}
+import org.apache.spark.sql.sources.v2.reader.{ScanBuilder, SupportsPushDownRequiredColumns}
 import org.apache.spark.sql.types.StructType
 
-abstract class FileScanBuilder(schema: StructType)
-  extends ScanBuilder
-  with SupportsPushDownRequiredColumns {
-  protected var readSchema = schema
+abstract class FileScanBuilder(
+    sparkSession: SparkSession,
+    fileIndex: PartitioningAwareFileIndex,
+    dataSchema: StructType) extends ScanBuilder with SupportsPushDownRequiredColumns {
+  private val partitionSchema = fileIndex.partitionSchema
+  private val isCaseSensitive = sparkSession.sessionState.conf.caseSensitiveAnalysis
+  protected var requiredSchema = StructType(dataSchema.fields ++ partitionSchema.fields)
 
   override def pruneColumns(requiredSchema: StructType): Unit = {
-    this.readSchema = requiredSchema
+    this.requiredSchema = requiredSchema
   }
+
+  protected def readDataSchema(): StructType = {
+    val requiredNameSet = createRequiredNameSet()
+    val fields = dataSchema.fields.filter { field =>
+      val colName = PartitioningUtils.getColName(field, isCaseSensitive)
+      requiredNameSet.contains(colName) && !partitionNameSet.contains(colName)
+    }
+    StructType(fields)
+  }
+
+  protected def readPartitionSchema(): StructType = {
+    val requiredNameSet = createRequiredNameSet()
+    val fields = partitionSchema.fields.filter { field =>
+      val colName = PartitioningUtils.getColName(field, isCaseSensitive)
+      requiredNameSet.contains(colName)
+    }
+    StructType(fields)
+  }
+
+  private def createRequiredNameSet(): Set[String] =
+    requiredSchema.fields.map(PartitioningUtils.getColName(_, isCaseSensitive)).toSet
+
+  private val partitionNameSet: Set[String] =
+    partitionSchema.fields.map(PartitioningUtils.getColName(_, isCaseSensitive)).toSet
 }
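
For readers following the new `readDataSchema()`/`readPartitionSchema()` split, here is a minimal sketch of the same name-set based splitting using only the public `StructType` API. It is simplified to always compare names case-insensitively, whereas the real code honors `spark.sql.caseSensitive` via `PartitioningUtils.getColName`:

```scala
import java.util.Locale

import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType}

// Split a pruned (required) schema into the data columns the file reader must
// decode and the partition columns whose values come from the directory layout.
def splitRequiredSchema(
    dataSchema: StructType,
    partitionSchema: StructType,
    requiredSchema: StructType): (StructType, StructType) = {
  def norm(name: String): String = name.toLowerCase(Locale.ROOT)
  val requiredNames = requiredSchema.fields.map(f => norm(f.name)).toSet
  val partitionNames = partitionSchema.fields.map(f => norm(f.name)).toSet
  val readDataSchema = StructType(dataSchema.fields.filter { f =>
    requiredNames.contains(norm(f.name)) && !partitionNames.contains(norm(f.name))
  })
  val readPartitionSchema = StructType(partitionSchema.fields.filter { f =>
    requiredNames.contains(norm(f.name))
  })
  (readDataSchema, readPartitionSchema)
}

// Example: only `id` and `month` are required from a table partitioned by (year, month).
val (readData, readPartition) = splitRequiredSchema(
  dataSchema = StructType(Seq(StructField("id", IntegerType), StructField("name", StringType))),
  partitionSchema = StructType(Seq(StructField("year", IntegerType), StructField("month", IntegerType))),
  requiredSchema = StructType(Seq(StructField("id", IntegerType), StructField("month", IntegerType))))
// readData      contains only `id`
// readPartition contains only `month`
```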

sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/TextBasedFileScan.scala

Lines changed: 3 additions & 2 deletions
@@ -29,9 +29,10 @@ import org.apache.spark.sql.util.CaseInsensitiveStringMap
 abstract class TextBasedFileScan(
     sparkSession: SparkSession,
     fileIndex: PartitioningAwareFileIndex,
-    readSchema: StructType,
+    readDataSchema: StructType,
+    readPartitionSchema: StructType,
     options: CaseInsensitiveStringMap)
-  extends FileScan(sparkSession, fileIndex, readSchema, options) {
+  extends FileScan(sparkSession, fileIndex, readDataSchema, readPartitionSchema) {
   private var codecFactory: CompressionCodecFactory = _
 
   override def isSplitable(path: Path): Boolean = {

sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/csv/CSVPartitionReaderFactory.scala

Lines changed: 2 additions & 5 deletions
@@ -33,24 +33,21 @@ import org.apache.spark.util.SerializableConfiguration
  * @param sqlConf SQL configuration.
  * @param broadcastedConf Broadcasted serializable Hadoop Configuration.
  * @param dataSchema Schema of CSV files.
+ * @param readDataSchema Required data schema in the batch scan.
  * @param partitionSchema Schema of partitions.
- * @param readSchema Required schema in the batch scan.
  * @param parsedOptions Options for parsing CSV files.
  */
 case class CSVPartitionReaderFactory(
     sqlConf: SQLConf,
     broadcastedConf: Broadcast[SerializableConfiguration],
     dataSchema: StructType,
+    readDataSchema: StructType,
     partitionSchema: StructType,
-    readSchema: StructType,
     parsedOptions: CSVOptions) extends FilePartitionReaderFactory {
   private val columnPruning = sqlConf.csvColumnPruning
-  private val readDataSchema =
-    getReadDataSchema(readSchema, partitionSchema, sqlConf.caseSensitiveAnalysis)
 
   override def buildReader(file: PartitionedFile): PartitionReader[InternalRow] = {
     val conf = broadcastedConf.value.value
-
     val parser = new UnivocityParser(
       StructType(dataSchema.filterNot(_.name == parsedOptions.columnNameOfCorruptRecord)),
       StructType(readDataSchema.filterNot(_.name == parsedOptions.columnNameOfCorruptRecord)),

sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/csv/CSVScan.scala

Lines changed: 8 additions & 5 deletions
@@ -35,9 +35,10 @@ case class CSVScan(
     sparkSession: SparkSession,
     fileIndex: PartitioningAwareFileIndex,
     dataSchema: StructType,
-    readSchema: StructType,
+    readDataSchema: StructType,
+    readPartitionSchema: StructType,
     options: CaseInsensitiveStringMap)
-  extends TextBasedFileScan(sparkSession, fileIndex, readSchema, options) {
+  extends TextBasedFileScan(sparkSession, fileIndex, readDataSchema, readPartitionSchema, options) {
 
   private lazy val parsedOptions: CSVOptions = new CSVOptions(
     options.asScala.toMap,
@@ -53,8 +54,8 @@ case class CSVScan(
     // Check a field requirement for corrupt records here to throw an exception in a driver side
     ExprUtils.verifyColumnNameOfCorruptRecord(dataSchema, parsedOptions.columnNameOfCorruptRecord)
 
-    if (readSchema.length == 1 &&
-      readSchema.head.name == parsedOptions.columnNameOfCorruptRecord) {
+    if (readDataSchema.length == 1 &&
+      readDataSchema.head.name == parsedOptions.columnNameOfCorruptRecord) {
       throw new AnalysisException(
         "Since Spark 2.3, the queries from raw JSON/CSV files are disallowed when the\n" +
           "referenced columns only include the internal corrupt record column\n" +
@@ -72,7 +73,9 @@ case class CSVScan(
     val hadoopConf = sparkSession.sessionState.newHadoopConfWithOptions(caseSensitiveMap)
     val broadcastedConf = sparkSession.sparkContext.broadcast(
       new SerializableConfiguration(hadoopConf))
+    // The partition values are already truncated in `FileScan.partitions`.
+    // We should use `readPartitionSchema` as the partition schema here.
     CSVPartitionReaderFactory(sparkSession.sessionState.conf, broadcastedConf,
-      dataSchema, fileIndex.partitionSchema, readSchema, parsedOptions)
+      dataSchema, readDataSchema, readPartitionSchema, parsedOptions)
   }
 }

sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/csv/CSVScanBuilder.scala

Lines changed: 3 additions & 2 deletions
@@ -29,9 +29,10 @@ case class CSVScanBuilder(
     fileIndex: PartitioningAwareFileIndex,
     schema: StructType,
     dataSchema: StructType,
-    options: CaseInsensitiveStringMap) extends FileScanBuilder(schema) {
+    options: CaseInsensitiveStringMap)
+  extends FileScanBuilder(sparkSession, fileIndex, dataSchema) {
 
   override def build(): Scan = {
-    CSVScan(sparkSession, fileIndex, dataSchema, readSchema, options)
+    CSVScan(sparkSession, fileIndex, dataSchema, readDataSchema(), readPartitionSchema(), options)
   }
 }

sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcPartitionReaderFactory.scala

Lines changed: 17 additions & 20 deletions
@@ -46,30 +46,30 @@ import org.apache.spark.util.SerializableConfiguration
  * @param sqlConf SQL configuration.
  * @param broadcastedConf Broadcast serializable Hadoop Configuration.
  * @param dataSchema Schema of orc files.
+ * @param readDataSchema Required data schema in the batch scan.
  * @param partitionSchema Schema of partitions.
- * @param readSchema Required schema in the batch scan.
  */
 case class OrcPartitionReaderFactory(
     sqlConf: SQLConf,
     broadcastedConf: Broadcast[SerializableConfiguration],
     dataSchema: StructType,
-    partitionSchema: StructType,
-    readSchema: StructType) extends FilePartitionReaderFactory {
+    readDataSchema: StructType,
+    partitionSchema: StructType) extends FilePartitionReaderFactory {
+  private val resultSchema = StructType(readDataSchema.fields ++ partitionSchema.fields)
   private val isCaseSensitive = sqlConf.caseSensitiveAnalysis
   private val capacity = sqlConf.orcVectorizedReaderBatchSize
 
   override def supportColumnarReads(partition: InputPartition): Boolean = {
     sqlConf.orcVectorizedReaderEnabled && sqlConf.wholeStageEnabled &&
-      readSchema.length <= sqlConf.wholeStageMaxNumFields &&
-      readSchema.forall(_.dataType.isInstanceOf[AtomicType])
+      resultSchema.length <= sqlConf.wholeStageMaxNumFields &&
+      resultSchema.forall(_.dataType.isInstanceOf[AtomicType])
   }
 
   override def buildReader(file: PartitionedFile): PartitionReader[InternalRow] = {
     val conf = broadcastedConf.value.value
 
-    val readDataSchema = getReadDataSchema(readSchema, partitionSchema, isCaseSensitive)
-    val readDataSchemaString = OrcUtils.orcTypeDescriptionString(readDataSchema)
-    OrcConf.MAPRED_INPUT_SCHEMA.setString(conf, readDataSchemaString)
+    val resultSchemaString = OrcUtils.orcTypeDescriptionString(resultSchema)
+    OrcConf.MAPRED_INPUT_SCHEMA.setString(conf, resultSchemaString)
     OrcConf.IS_SCHEMA_EVOLUTION_CASE_SENSITIVE.setBoolean(conf, isCaseSensitive)
 
     val filePath = new Path(new URI(file.filePath))
@@ -113,8 +113,8 @@ case class OrcPartitionReaderFactory(
   override def buildColumnarReader(file: PartitionedFile): PartitionReader[ColumnarBatch] = {
     val conf = broadcastedConf.value.value
 
-    val readSchemaString = OrcUtils.orcTypeDescriptionString(readSchema)
-    OrcConf.MAPRED_INPUT_SCHEMA.setString(conf, readSchemaString)
+    val resultSchemaString = OrcUtils.orcTypeDescriptionString(resultSchema)
+    OrcConf.MAPRED_INPUT_SCHEMA.setString(conf, resultSchemaString)
     OrcConf.IS_SCHEMA_EVOLUTION_CASE_SENSITIVE.setBoolean(conf, isCaseSensitive)
 
     val filePath = new Path(new URI(file.filePath))
@@ -124,13 +124,13 @@ case class OrcPartitionReaderFactory(
     val reader = OrcFile.createReader(filePath, readerOptions)
 
     val requestedColIdsOrEmptyFile = OrcUtils.requestedColumnIds(
-      isCaseSensitive, dataSchema, readSchema, reader, conf)
+      isCaseSensitive, dataSchema, readDataSchema, reader, conf)
 
     if (requestedColIdsOrEmptyFile.isEmpty) {
       new EmptyPartitionReader
     } else {
-      val requestedColIds = requestedColIdsOrEmptyFile.get
-      assert(requestedColIds.length == readSchema.length,
+      val requestedColIds = requestedColIdsOrEmptyFile.get ++ Array.fill(partitionSchema.length)(-1)
+      assert(requestedColIds.length == resultSchema.length,
         "[BUG] requested column IDs do not match required schema")
       val taskConf = new Configuration(conf)
 
@@ -140,15 +140,12 @@ case class OrcPartitionReaderFactory(
 
       val batchReader = new OrcColumnarBatchReader(capacity)
       batchReader.initialize(fileSplit, taskAttemptContext)
-      val columnNameMap = partitionSchema.fields.map(
-        PartitioningUtils.getColName(_, isCaseSensitive)).zipWithIndex.toMap
-      val requestedPartitionColIds = readSchema.fields.map { field =>
-        columnNameMap.getOrElse(PartitioningUtils.getColName(field, isCaseSensitive), -1)
-      }
+      val requestedPartitionColIds =
+        Array.fill(readDataSchema.length)(-1) ++ Range(0, partitionSchema.length)
 
       batchReader.initBatch(
-        TypeDescription.fromString(readSchemaString),
-        readSchema.fields,
+        TypeDescription.fromString(resultSchemaString),
+        resultSchema.fields,
         requestedColIds,
         requestedPartitionColIds,
         file.partitionValues)
sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcScan.scala

Lines changed: 6 additions & 3 deletions
@@ -32,15 +32,18 @@ case class OrcScan(
     hadoopConf: Configuration,
     fileIndex: PartitioningAwareFileIndex,
     dataSchema: StructType,
-    readSchema: StructType,
+    readDataSchema: StructType,
+    readPartitionSchema: StructType,
     options: CaseInsensitiveStringMap)
-  extends FileScan(sparkSession, fileIndex, readSchema, options) {
+  extends FileScan(sparkSession, fileIndex, readDataSchema, readPartitionSchema) {
   override def isSplitable(path: Path): Boolean = true
 
   override def createReaderFactory(): PartitionReaderFactory = {
     val broadcastedConf = sparkSession.sparkContext.broadcast(
       new SerializableConfiguration(hadoopConf))
+    // The partition values are already truncated in `FileScan.partitions`.
+    // We should use `readPartitionSchema` as the partition schema here.
     OrcPartitionReaderFactory(sparkSession.sessionState.conf, broadcastedConf,
-      dataSchema, fileIndex.partitionSchema, readSchema)
+      dataSchema, readDataSchema, readPartitionSchema)
   }
 }

sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcScanBuilder.scala

Lines changed: 3 additions & 2 deletions
@@ -36,15 +36,16 @@ case class OrcScanBuilder(
     schema: StructType,
     dataSchema: StructType,
     options: CaseInsensitiveStringMap)
-  extends FileScanBuilder(schema) with SupportsPushDownFilters {
+  extends FileScanBuilder(sparkSession, fileIndex, dataSchema) with SupportsPushDownFilters {
   lazy val hadoopConf = {
     val caseSensitiveMap = options.asCaseSensitiveMap.asScala.toMap
     // Hadoop Configurations are case sensitive.
     sparkSession.sessionState.newHadoopConfWithOptions(caseSensitiveMap)
   }
 
   override def build(): Scan = {
-    OrcScan(sparkSession, hadoopConf, fileIndex, dataSchema, readSchema, options)
+    OrcScan(sparkSession, hadoopConf, fileIndex, dataSchema,
+      readDataSchema(), readPartitionSchema(), options)
   }
 
   private var _pushedFilters: Array[Filter] = Array.empty
