
Commit 2ffa817

ulysses-you authored and cloud-fan committed

[SPARK-41407][SQL] Pull out v1 write to WriteFiles
### What changes were proposed in this pull request?

This PR pulls the details of v1 write files out into a new operator: `WriteFiles` (logical) and `WriteFilesExec` (physical). This lets us make v1 write files support whole-stage codegen in the future.

It introduces `WriteFilesSpec` to hold all v1 write files information:

```scala
case class WriteFilesSpec(
    description: WriteJobDescription,
    committer: FileCommitProtocol,
    concurrentOutputWriterSpecFunc: SparkPlan => Option[ConcurrentOutputWriterSpec])
  extends WriteSpec
```

To stay compatible with the existing code path, this PR adds a new method `executeWrite` to `SparkPlan`:

```scala
def executeWrite(writeSpec: WriteSpec): RDD[WriterCommitMessage]
```

It also refactors `FileFormatWriter` to make the write-files flow clearer:

- execute the write using the old code path
- execute the write using `SparkPlan.executeWrite`
- extract a `writeAndCommit` method that works with both code paths

### Why are the changes needed?

This is preparation work for supporting whole-stage codegen in v1 writes.

### Does this PR introduce _any_ user-facing change?

For users, no. For developers, yes:

- adds a new method `executeWrite` to `SparkPlan`
- adds a new interface `WriteSpec`

### How was this patch tested?

Passes CI with `spark.sql.optimizer.plannedWrite.enabled` both on and off.

Closes #38939 from ulysses-you/v1write-plan.

Authored-by: ulysses-you <ulyssesyou18@gmail.com>
Signed-off-by: Wenchen Fan <wenchen@databricks.com>
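For orientation, here is a minimal driver-side sketch of the new code path described above. It is not code from this commit: the import paths and the surrounding values (`description`, `committer`, `physicalPlan`) are assumed to be the ones `FileFormatWriter` already has in scope, and the concurrent-writer function is stubbed out.

```scala
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.connector.write.WriterCommitMessage
import org.apache.spark.sql.execution.SparkPlan
import org.apache.spark.sql.execution.datasources.WriteFilesSpec

// Sketch only: bundle the v1 write information into the new spec...
val writeSpec = WriteFilesSpec(
  description = description,   // WriteJobDescription for this write job
  committer = committer,        // FileCommitProtocol that commits the job
  concurrentOutputWriterSpecFunc = (_: SparkPlan) => None) // stubbed in this sketch

// ...and let the physical plan (which contains a WriteFilesExec node) run the
// write, returning one WriterCommitMessage per partition.
val messages: RDD[WriterCommitMessage] = physicalPlan.executeWrite(writeSpec)
```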
1 parent a1c727f commit 2ffa817

File tree: 11 files changed, +334 −53 lines

core/src/main/resources/error/error-classes.json

Lines changed: 1 addition & 1 deletion
```diff
@@ -3693,7 +3693,7 @@
   },
   "_LEGACY_ERROR_TEMP_2054" : {
     "message" : [
-      "Task failed while writing rows."
+      "Task failed while writing rows. <message>"
     ]
   },
   "_LEGACY_ERROR_TEMP_2055" : {
```

sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala

Lines changed: 1 addition & 1 deletion
```diff
@@ -785,7 +785,7 @@ private[sql] object QueryExecutionErrors extends QueryErrorsBase {
   def taskFailedWhileWritingRowsError(cause: Throwable): Throwable = {
     new SparkException(
       errorClass = "_LEGACY_ERROR_TEMP_2054",
-      messageParameters = Map.empty,
+      messageParameters = Map("message" -> cause.getMessage),
       cause = cause)
   }
```

WriteSpec.java (new file; package org.apache.spark.sql.internal)

Lines changed: 33 additions & 0 deletions

```java
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.sql.internal;

import java.io.Serializable;

/**
 * Write spec is a input parameter of
 * {@link org.apache.spark.sql.execution.SparkPlan#executeWrite}.
 *
 * <p>
 * This is an empty interface, the concrete class which implements
 * {@link org.apache.spark.sql.execution.SparkPlan#doExecuteWrite}
 * should define its own class and use it.
 *
 * @since 3.4.0
 */
public interface WriteSpec extends Serializable {}
```
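Since `WriteSpec` is just a serializable marker, each operator that overrides `doExecuteWrite` pairs it with its own concrete spec (in this commit, the `WriteFilesSpec` shown in the description above). A hypothetical sketch of what such a spec can look like; the name and fields below are illustrative only:

```scala
import org.apache.spark.sql.internal.WriteSpec

// Hypothetical spec for some custom write node: a case class is already
// Serializable, so extending the marker interface is all that is required.
// Real specs (e.g. WriteFilesSpec) carry the WriteJobDescription, the
// FileCommitProtocol and the concurrent-writer function instead.
case class SimpleTextWriteSpec(
    outputPath: String,
    compress: Boolean) extends WriteSpec
```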

sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala

Lines changed: 25 additions & 1 deletion
```diff
@@ -34,9 +34,10 @@ import org.apache.spark.sql.catalyst.plans.QueryPlan
 import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
 import org.apache.spark.sql.catalyst.plans.physical._
 import org.apache.spark.sql.catalyst.trees.{BinaryLike, LeafLike, TreeNodeTag, UnaryLike}
+import org.apache.spark.sql.connector.write.WriterCommitMessage
 import org.apache.spark.sql.errors.QueryExecutionErrors
 import org.apache.spark.sql.execution.metric.SQLMetric
-import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.internal.{SQLConf, WriteSpec}
 import org.apache.spark.sql.vectorized.ColumnarBatch
 import org.apache.spark.util.NextIterator
 import org.apache.spark.util.io.{ChunkedByteBuffer, ChunkedByteBufferOutputStream}
@@ -223,6 +224,19 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ
     doExecuteColumnar()
   }

+  /**
+   * Returns the result of writes as an RDD[WriterCommitMessage] variable by delegating to
+   * `doExecuteWrite` after preparations.
+   *
+   * Concrete implementations of SparkPlan should override `doExecuteWrite`.
+   */
+  def executeWrite(writeSpec: WriteSpec): RDD[WriterCommitMessage] = executeQuery {
+    if (isCanonicalizedPlan) {
+      throw SparkException.internalError("A canonicalized plan is not supposed to be executed.")
+    }
+    doExecuteWrite(writeSpec)
+  }
+
   /**
    * Executes a query after preparing the query and adding query plan information to created RDDs
    * for visualization.
@@ -324,6 +338,16 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ
       s" mismatch:\n${this}")
   }

+  /**
+   * Produces the result of the writes as an `RDD[WriterCommitMessage]`
+   *
+   * Overridden by concrete implementations of SparkPlan.
+   */
+  protected def doExecuteWrite(writeSpec: WriteSpec): RDD[WriterCommitMessage] = {
+    throw SparkException.internalError(s"Internal Error ${this.getClass} has write support" +
+      s" mismatch:\n${this}")
+  }
+
   /**
    * Converts the output of this plan to row-based if it is columnar plan.
    */
```
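A physical operator opts into this API by overriding `doExecuteWrite`; callers always go through `executeWrite`, which adds the canonicalized-plan check and the usual `executeQuery` preparation. Below is a hedged sketch of a hypothetical write-only node (the commit's real implementation is `WriteFilesExec`, which is not shown in this diff; the node and its body here are illustrative only):

```scala
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.Attribute
import org.apache.spark.sql.connector.write.WriterCommitMessage
import org.apache.spark.sql.execution.{SparkPlan, UnaryExecNode}
import org.apache.spark.sql.internal.WriteSpec

// Minimal serializable commit message for the sketch.
case class ExampleCommitMessage(rowCount: Long) extends WriterCommitMessage

// Hypothetical write-only node: rows flow in from `child`, commit messages
// flow out. Row-based execution is intentionally unsupported.
case class ExampleWriteExec(child: SparkPlan) extends UnaryExecNode {
  override def output: Seq[Attribute] = Seq.empty

  override protected def doExecute(): RDD[InternalRow] =
    throw new UnsupportedOperationException(s"$nodeName only supports executeWrite")

  override protected def doExecuteWrite(writeSpec: WriteSpec): RDD[WriterCommitMessage] = {
    child.execute().mapPartitions { rows =>
      // A real node would interpret `writeSpec` and write `rows` out; this
      // sketch only counts them and reports one message per partition.
      Iterator.single(ExampleCommitMessage(rows.size.toLong))
    }
  }

  override protected def withNewChildInternal(newChild: SparkPlan): SparkPlan =
    copy(child = newChild)
}
```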

sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala

Lines changed: 3 additions & 0 deletions
```diff
@@ -34,6 +34,7 @@ import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryExecutionErrors
 import org.apache.spark.sql.execution.aggregate.AggUtils
 import org.apache.spark.sql.execution.columnar.{InMemoryRelation, InMemoryTableScanExec}
 import org.apache.spark.sql.execution.command._
+import org.apache.spark.sql.execution.datasources.{WriteFiles, WriteFilesExec}
 import org.apache.spark.sql.execution.exchange.{REBALANCE_PARTITIONS_BY_COL, REBALANCE_PARTITIONS_BY_NONE, REPARTITION_BY_COL, REPARTITION_BY_NUM, ShuffleExchangeExec}
 import org.apache.spark.sql.execution.python._
 import org.apache.spark.sql.execution.streaming._
@@ -894,6 +895,8 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
         throw QueryExecutionErrors.ddlUnsupportedTemporarilyError("MERGE INTO TABLE")
       case logical.CollectMetrics(name, metrics, child) =>
         execution.CollectMetricsExec(name, metrics, planLater(child)) :: Nil
+      case WriteFiles(child) =>
+        WriteFilesExec(planLater(child)) :: Nil
       case _ => Nil
     }
   }
```

sql/core/src/main/scala/org/apache/spark/sql/execution/command/createDataSourceTables.scala

Lines changed: 6 additions & 0 deletions
```diff
@@ -145,6 +145,12 @@ case class CreateDataSourceTableAsSelectCommand(
     outputColumnNames: Seq[String])
   extends V1WriteCommand {

+  override def fileFormatProvider: Boolean = {
+    table.provider.forall { provider =>
+      classOf[FileFormat].isAssignableFrom(DataSource.providingClass(provider, conf))
+    }
+  }
+
   override lazy val partitionColumns: Seq[Attribute] = {
     val unresolvedPartitionColumns = table.partitionColumnNames.map(UnresolvedAttribute.quoted)
     DataSource.resolvePartitionColumns(
```

sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala

Lines changed: 16 additions & 13 deletions
```diff
@@ -97,19 +97,8 @@ case class DataSource(

   case class SourceInfo(name: String, schema: StructType, partitionColumns: Seq[String])

-  lazy val providingClass: Class[_] = {
-    val cls = DataSource.lookupDataSource(className, sparkSession.sessionState.conf)
-    // `providingClass` is used for resolving data source relation for catalog tables.
-    // As now catalog for data source V2 is under development, here we fall back all the
-    // [[FileDataSourceV2]] to [[FileFormat]] to guarantee the current catalog works.
-    // [[FileDataSourceV2]] will still be used if we call the load()/save() method in
-    // [[DataFrameReader]]/[[DataFrameWriter]], since they use method `lookupDataSource`
-    // instead of `providingClass`.
-    cls.newInstance() match {
-      case f: FileDataSourceV2 => f.fallbackFileFormat
-      case _ => cls
-    }
-  }
+  lazy val providingClass: Class[_] =
+    DataSource.providingClass(className, sparkSession.sessionState.conf)

   private[sql] def providingInstance(): Any = providingClass.getConstructor().newInstance()

@@ -843,4 +832,18 @@ object DataSource extends Logging {
       }
     }
   }
+
+  def providingClass(className: String, conf: SQLConf): Class[_] = {
+    val cls = DataSource.lookupDataSource(className, conf)
+    // `providingClass` is used for resolving data source relation for catalog tables.
+    // As now catalog for data source V2 is under development, here we fall back all the
+    // [[FileDataSourceV2]] to [[FileFormat]] to guarantee the current catalog works.
+    // [[FileDataSourceV2]] will still be used if we call the load()/save() method in
+    // [[DataFrameReader]]/[[DataFrameWriter]], since they use method `lookupDataSource`
+    // instead of `providingClass`.
+    cls.newInstance() match {
+      case f: FileDataSourceV2 => f.fallbackFileFormat
+      case _ => cls
+    }
+  }
 }
```

sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala

Lines changed: 122 additions & 32 deletions
```diff
@@ -36,6 +36,7 @@ import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.BindReferences.bindReferences
 import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, DateTimeUtils}
+import org.apache.spark.sql.connector.write.WriterCommitMessage
 import org.apache.spark.sql.errors.QueryExecutionErrors
 import org.apache.spark.sql.execution.{ProjectExec, SortExec, SparkPlan, SQLExecution, UnsafeExternalRowSorter}
 import org.apache.spark.sql.internal.SQLConf
@@ -103,14 +104,6 @@
         .map(FileSourceMetadataAttribute.cleanupFileSourceMetadataInformation))
     val dataColumns = finalOutputSpec.outputColumns.filterNot(partitionSet.contains)

-    val hasEmpty2Null = plan.exists(p => V1WritesUtils.hasEmptyToNull(p.expressions))
-    val empty2NullPlan = if (hasEmpty2Null) {
-      plan
-    } else {
-      val projectList = V1WritesUtils.convertEmptyToNull(plan.output, partitionColumns)
-      if (projectList.nonEmpty) ProjectExec(projectList, plan) else plan
-    }
-
     val writerBucketSpec = V1WritesUtils.getWriterBucketSpec(bucketSpec, dataColumns, options)
     val sortColumns = V1WritesUtils.getBucketSortColumns(bucketSpec, dataColumns)

@@ -144,9 +137,10 @@
     // columns.
     val requiredOrdering = partitionColumns.drop(numStaticPartitionCols) ++
       writerBucketSpec.map(_.bucketIdExpression) ++ sortColumns
+    val writeFilesOpt = V1WritesUtils.getWriteFilesOpt(plan)
     // the sort order doesn't matter
     // Use the output ordering from the original plan before adding the empty2null projection.
-    val actualOrdering = plan.outputOrdering.map(_.child)
+    val actualOrdering = writeFilesOpt.map(_.child).getOrElse(plan).outputOrdering.map(_.child)
     val orderingMatched = V1WritesUtils.isOrderingMatched(requiredOrdering, actualOrdering)

     SQLExecution.checkSQLExecutionId(sparkSession)
@@ -155,10 +149,6 @@
     // get an ID guaranteed to be unique.
     job.getConfiguration.set("spark.sql.sources.writeJobUUID", description.uuid)

-    // This call shouldn't be put into the `try` block below because it only initializes and
-    // prepares the job, any exception thrown from here shouldn't cause abortJob() to be called.
-    committer.setupJob(job)
-
     // When `PLANNED_WRITE_ENABLED` is true, the optimizer rule V1Writes will add logical sort
     // operator based on the required ordering of the V1 write command. So the output
     // ordering of the physical plan should always match the required ordering. Here
@@ -169,27 +159,55 @@
     // V1 write command will be empty).
     if (Utils.isTesting) outputOrderingMatched = orderingMatched

-    try {
+    if (writeFilesOpt.isDefined) {
+      // build `WriteFilesSpec` for `WriteFiles`
+      val concurrentOutputWriterSpecFunc = (plan: SparkPlan) => {
+        val sortPlan = createSortPlan(plan, requiredOrdering, outputSpec)
+        createConcurrentOutputWriterSpec(sparkSession, sortPlan, sortColumns)
+      }
+      val writeSpec = WriteFilesSpec(
+        description = description,
+        committer = committer,
+        concurrentOutputWriterSpecFunc = concurrentOutputWriterSpecFunc
+      )
+      executeWrite(sparkSession, plan, writeSpec, job)
+    } else {
+      executeWrite(sparkSession, plan, job, description, committer, outputSpec,
+        requiredOrdering, partitionColumns, sortColumns, orderingMatched)
+    }
+  }
+  // scalastyle:on argcount
+
+  private def executeWrite(
+      sparkSession: SparkSession,
+      plan: SparkPlan,
+      job: Job,
+      description: WriteJobDescription,
+      committer: FileCommitProtocol,
+      outputSpec: OutputSpec,
+      requiredOrdering: Seq[Expression],
+      partitionColumns: Seq[Attribute],
+      sortColumns: Seq[Attribute],
+      orderingMatched: Boolean): Set[String] = {
+    val hasEmpty2Null = plan.exists(p => V1WritesUtils.hasEmptyToNull(p.expressions))
+    val empty2NullPlan = if (hasEmpty2Null) {
+      plan
+    } else {
+      val projectList = V1WritesUtils.convertEmptyToNull(plan.output, partitionColumns)
+      if (projectList.nonEmpty) ProjectExec(projectList, plan) else plan
+    }
+
+    writeAndCommit(job, description, committer) {
       val (rdd, concurrentOutputWriterSpec) = if (orderingMatched) {
         (empty2NullPlan.execute(), None)
       } else {
-        // SPARK-21165: the `requiredOrdering` is based on the attributes from analyzed plan, and
-        // the physical plan may have different attribute ids due to optimizer removing some
-        // aliases. Here we bind the expression ahead to avoid potential attribute ids mismatch.
-        val orderingExpr = bindReferences(
-          requiredOrdering.map(SortOrder(_, Ascending)), finalOutputSpec.outputColumns)
-        val sortPlan = SortExec(
-          orderingExpr,
-          global = false,
-          child = empty2NullPlan)
-
-        val maxWriters = sparkSession.sessionState.conf.maxConcurrentOutputFileWriters
-        val concurrentWritersEnabled = maxWriters > 0 && sortColumns.isEmpty
-        if (concurrentWritersEnabled) {
-          (empty2NullPlan.execute(),
-            Some(ConcurrentOutputWriterSpec(maxWriters, () => sortPlan.createSorter())))
+        val sortPlan = createSortPlan(empty2NullPlan, requiredOrdering, outputSpec)
+        val concurrentOutputWriterSpec = createConcurrentOutputWriterSpec(
+          sparkSession, sortPlan, sortColumns)
+        if (concurrentOutputWriterSpec.isDefined) {
+          (empty2NullPlan.execute(), concurrentOutputWriterSpec)
         } else {
-          (sortPlan.execute(), None)
+          (sortPlan.execute(), concurrentOutputWriterSpec)
         }
       }

@@ -221,7 +239,19 @@
         committer.onTaskCommit(res.commitMsg)
         ret(index) = res
       })
+      ret
+    }
+  }

+  private def writeAndCommit(
+      job: Job,
+      description: WriteJobDescription,
+      committer: FileCommitProtocol)(f: => Array[WriteTaskResult]): Set[String] = {
+    // This call shouldn't be put into the `try` block below because it only initializes and
+    // prepares the job, any exception thrown from here shouldn't cause abortJob() to be called.
+    committer.setupJob(job)
+    try {
+      val ret = f
       val commitMsgs = ret.map(_.commitMsg)

       logInfo(s"Start to commit write Job ${description.uuid}.")
@@ -239,10 +269,70 @@
       throw cause
     }
   }
-  // scalastyle:on argcount
+
+  /**
+   * Write files using [[SparkPlan.executeWrite]]
+   */
+  private def executeWrite(
+      session: SparkSession,
+      planForWrites: SparkPlan,
+      writeFilesSpec: WriteFilesSpec,
+      job: Job): Set[String] = {
+    val committer = writeFilesSpec.committer
+    val description = writeFilesSpec.description
+
+    writeAndCommit(job, description, committer) {
+      val rdd = planForWrites.executeWrite(writeFilesSpec)
+      val ret = new Array[WriteTaskResult](rdd.partitions.length)
+      session.sparkContext.runJob(
+        rdd,
+        (context: TaskContext, iter: Iterator[WriterCommitMessage]) => {
+          assert(iter.hasNext)
+          val commitMessage = iter.next()
+          assert(!iter.hasNext)
+          commitMessage
+        },
+        rdd.partitions.indices,
+        (index, res: WriterCommitMessage) => {
+          assert(res.isInstanceOf[WriteTaskResult])
+          val writeTaskResult = res.asInstanceOf[WriteTaskResult]
+          committer.onTaskCommit(writeTaskResult.commitMsg)
+          ret(index) = writeTaskResult
+        })
+      ret
+    }
+  }
+
+  private def createSortPlan(
+      plan: SparkPlan,
+      requiredOrdering: Seq[Expression],
+      outputSpec: OutputSpec): SortExec = {
+    // SPARK-21165: the `requiredOrdering` is based on the attributes from analyzed plan, and
+    // the physical plan may have different attribute ids due to optimizer removing some
+    // aliases. Here we bind the expression ahead to avoid potential attribute ids mismatch.
+    val orderingExpr = bindReferences(
+      requiredOrdering.map(SortOrder(_, Ascending)), outputSpec.outputColumns)
+    SortExec(
+      orderingExpr,
+      global = false,
+      child = plan)
+  }
+
+  private def createConcurrentOutputWriterSpec(
+      sparkSession: SparkSession,
+      sortPlan: SortExec,
+      sortColumns: Seq[Attribute]): Option[ConcurrentOutputWriterSpec] = {
+    val maxWriters = sparkSession.sessionState.conf.maxConcurrentOutputFileWriters
+    val concurrentWritersEnabled = maxWriters > 0 && sortColumns.isEmpty
+    if (concurrentWritersEnabled) {
+      Some(ConcurrentOutputWriterSpec(maxWriters, () => sortPlan.createSorter()))
+    } else {
+      None
+    }
+  }

   /** Writes data out in a single Spark task. */
-  private def executeTask(
+  private[spark] def executeTask(
       description: WriteJobDescription,
       jobTrackerID: String,
       sparkStageId: Int,
```
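The extracted `writeAndCommit(job, description, committer)(f)` above uses a curried by-name block so that both write paths share the same setup/commit/abort lifecycle. Here is a standalone illustration of that pattern in plain Scala (generic names, not Spark API):

```scala
// Generic "setup, run body, commit or abort" helper, mirroring the shape of
// writeAndCommit: the by-name `body` is only evaluated inside the try block,
// while setup stays outside so its failures never trigger abort.
def runAndCommit[T](setup: () => Unit, commit: Seq[T] => Unit, abort: () => Unit)
    (body: => Seq[T]): Seq[T] = {
  setup()
  try {
    val results = body
    commit(results)
    results
  } catch {
    case t: Throwable =>
      abort()
      throw t
  }
}

// Both the classic path and the planned-write path can supply their own body.
val committed = runAndCommit[Int](
  setup = () => println("setupJob"),
  commit = rs => println(s"commitJob: ${rs.size} task results"),
  abort = () => println("abortJob"))(Seq(1, 2, 3))
```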
