
Commit 0827fb9

cloud-fan authored and otterc committed
[SPARK-24991][SQL] use InternalRow in DataSourceWriter
A follow-up of apache#21118. Since we use `InternalRow` in the read API of data source v2, we should do the same thing for the write API.

Tested with existing tests.

Author: Wenchen Fan <wenchen@databricks.com>

Closes apache#21948 from cloud-fan/row-write.

Ref: LIHADOOP-48531 RB=1855948 G=superfriends-reviewers R=yezhou,latang,mshen,fli,zolin A=
1 parent 2e736f0 commit 0827fb9

File tree

17 files changed: +84 -286 lines

external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaStreamWriter.scala

Lines changed: 2 additions & 2 deletions
@@ -42,11 +42,11 @@ case object KafkaWriterCommitMessage extends WriterCommitMessage
  */
 class KafkaStreamWriter(
     topic: Option[String], producerParams: Map[String, String], schema: StructType)
-  extends StreamWriter with SupportsWriteInternalRow {
+  extends StreamWriter {

   validateQuery(schema.toAttributes, producerParams.toMap[String, Object].asJava, topic)

-  override def createInternalRowWriterFactory(): KafkaStreamWriterFactory =
+  override def createWriterFactory(): KafkaStreamWriterFactory =
     KafkaStreamWriterFactory(topic, producerParams, schema)

   override def commit(epochId: Long, messages: Array[WriterCommitMessage]): Unit = {}

sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataSourceWriter.java

Lines changed: 2 additions & 2 deletions
@@ -18,8 +18,8 @@
 package org.apache.spark.sql.sources.v2.writer;

 import org.apache.spark.annotation.InterfaceStability;
-import org.apache.spark.sql.Row;
 import org.apache.spark.sql.SaveMode;
+import org.apache.spark.sql.catalyst.InternalRow;
 import org.apache.spark.sql.sources.v2.DataSourceOptions;
 import org.apache.spark.sql.sources.v2.StreamWriteSupport;
 import org.apache.spark.sql.sources.v2.WriteSupport;
@@ -61,7 +61,7 @@ public interface DataSourceWriter {
    * If this method fails (by throwing an exception), the action will fail and no Spark job will be
    * submitted.
    */
-  DataWriterFactory<Row> createWriterFactory();
+  DataWriterFactory<InternalRow> createWriterFactory();

  /**
   * Returns whether Spark should use the commit coordinator to ensure that at most one task for
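
For illustration only (not part of this commit): a minimal sketch of a sink written against the changed interface. Spark now asks the writer for a DataWriterFactory[InternalRow] directly, with no SupportsWriteInternalRow mix-in. The NoopWriter and NoopWriterFactory names and their discard-everything behavior are invented; only the interface signatures come from the code above.

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.sources.v2.writer._

// Hypothetical no-op sink: shows the shape of the write API after this change.
class NoopWriter extends DataSourceWriter {
  override def createWriterFactory(): DataWriterFactory[InternalRow] = NoopWriterFactory

  override def commit(messages: Array[WriterCommitMessage]): Unit = {}
  override def abort(messages: Array[WriterCommitMessage]): Unit = {}
}

case object NoopWriterFactory extends DataWriterFactory[InternalRow] {
  override def createDataWriter(
      partitionId: Int,
      taskId: Long,
      epochId: Long): DataWriter[InternalRow] = new DataWriter[InternalRow] {
    override def write(record: InternalRow): Unit = {}  // discard the record
    override def commit(): WriterCommitMessage = new WriterCommitMessage {}
    override def abort(): Unit = {}
  }
}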

sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataWriter.java

Lines changed: 1 addition & 3 deletions
@@ -53,9 +53,7 @@
  * successfully, and have a way to revert committed data writers without the commit message, because
  * Spark only accepts the commit message that arrives first and ignore others.
  *
- * Note that, Currently the type `T` can only be {@link org.apache.spark.sql.Row} for normal data
- * source writers, or {@link org.apache.spark.sql.catalyst.InternalRow} for data source writers
- * that mix in {@link SupportsWriteInternalRow}.
+ * Note that, Currently the type `T` can only be {@link org.apache.spark.sql.catalyst.InternalRow}.
  */
 @InterfaceStability.Evolving
 public interface DataWriter<T> {
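
Because `T` is now always InternalRow, a writer reads fields by ordinal with the Catalyst-typed getters (strings arrive as UTF8String, not java.lang.String) rather than the external Row accessors. A hedged sketch assuming a fixed (id: Long, name: String) schema; the CsvLineWriter class is invented for illustration.

import scala.collection.mutable.ArrayBuffer

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.sources.v2.writer.{DataWriter, WriterCommitMessage}

// Illustrative writer for a fixed (id: Long, name: String) schema.
class CsvLineWriter extends DataWriter[InternalRow] {
  private val lines = new ArrayBuffer[String]()

  override def write(record: InternalRow): Unit = {
    // InternalRow is accessed by ordinal and Catalyst type.
    val id = record.getLong(0)
    val name = record.getUTF8String(1).toString
    lines += s"$id,$name"
  }

  override def commit(): WriterCommitMessage = new WriterCommitMessage {}
  override def abort(): Unit = lines.clear()
}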

sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataWriterFactory.java

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,10 @@
3333
public interface DataWriterFactory<T> extends Serializable {
3434

3535
/**
36-
* Returns a data writer to do the actual writing work.
36+
* Returns a data writer to do the actual writing work. Note that, Spark will reuse the same data
37+
* object instance when sending data to the data writer, for better performance. Data writers
38+
* are responsible for defensive copies if necessary, e.g. copy the data before buffer it in a
39+
* list.
3740
*
3841
* If this method fails (by throwing an exception), the action will fail and no Spark job will be
3942
* submitted.
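
The added javadoc note has a practical consequence: Spark hands the same InternalRow instance to write() on every call, so a writer that keeps rows around must copy them first. A minimal sketch; the BufferingWriter class and its flush callback are hypothetical, and only InternalRow.copy() and the DataWriter interface come from Spark.

import scala.collection.mutable.ArrayBuffer

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.sources.v2.writer.{DataWriter, WriterCommitMessage}

// Hypothetical writer that buffers rows and flushes them on commit.
class BufferingWriter(flush: Seq[InternalRow] => WriterCommitMessage)
  extends DataWriter[InternalRow] {

  private val buffer = new ArrayBuffer[InternalRow]()

  override def write(record: InternalRow): Unit = {
    // Spark reuses the same row instance between calls, so copy before buffering.
    buffer += record.copy()
  }

  override def commit(): WriterCommitMessage = flush(buffer)

  override def abort(): Unit = buffer.clear()
}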

sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/SupportsWriteInternalRow.java

Lines changed: 0 additions & 41 deletions
This file was deleted.

sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2.scala

Lines changed: 1 addition & 29 deletions
@@ -50,11 +50,7 @@ case class WriteToDataSourceV2Exec(writer: DataSourceWriter, query: SparkPlan) e
   override def output: Seq[Attribute] = Nil

   override protected def doExecute(): RDD[InternalRow] = {
-    val writeTask = writer match {
-      case w: SupportsWriteInternalRow => w.createInternalRowWriterFactory()
-      case _ => new InternalRowDataWriterFactory(writer.createWriterFactory(), query.schema)
-    }
-
+    val writeTask = writer.createWriterFactory()
     val useCommitCoordinator = writer.useCommitCoordinator
     val rdd = query.execute()
     val messages = new Array[WriterCommitMessage](rdd.partitions.length)
@@ -155,27 +151,3 @@ object DataWritingSparkTask extends Logging {
     })
   }
 }
-
-class InternalRowDataWriterFactory(
-    rowWriterFactory: DataWriterFactory[Row],
-    schema: StructType) extends DataWriterFactory[InternalRow] {
-
-  override def createDataWriter(
-      partitionId: Int,
-      taskId: Long,
-      epochId: Long): DataWriter[InternalRow] = {
-    new InternalRowDataWriter(
-      rowWriterFactory.createDataWriter(partitionId, taskId, epochId),
-      RowEncoder.apply(schema).resolveAndBind())
-  }
-}
-
-class InternalRowDataWriter(rowWriter: DataWriter[Row], encoder: ExpressionEncoder[Row])
-  extends DataWriter[InternalRow] {
-
-  override def write(record: InternalRow): Unit = rowWriter.write(encoder.fromRow(record))
-
-  override def commit(): WriterCommitMessage = rowWriter.commit()
-
-  override def abort(): Unit = rowWriter.abort()
-}
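
The deleted bridge converted InternalRow back to Row on behalf of Row-based writers. With it gone, a source that still wants external Row objects has to do the conversion itself, for example with RowEncoder, exactly as the removed InternalRowDataWriter did. A sketch under that assumption; RowConvertingWriter is an invented name.

import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.encoders.RowEncoder
import org.apache.spark.sql.sources.v2.writer.{DataWriter, WriterCommitMessage}
import org.apache.spark.sql.types.StructType

// Hypothetical adapter mirroring the deleted InternalRowDataWriter, but owned
// by the data source implementation instead of Spark.
class RowConvertingWriter(schema: StructType, rowWriter: DataWriter[Row])
  extends DataWriter[InternalRow] {

  private val encoder = RowEncoder.apply(schema).resolveAndBind()

  override def write(record: InternalRow): Unit = rowWriter.write(encoder.fromRow(record))

  override def commit(): WriterCommitMessage = rowWriter.commit()

  override def abort(): Unit = rowWriter.abort()
}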

sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala

Lines changed: 2 additions & 8 deletions
@@ -28,10 +28,9 @@ import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap, Curre
 import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan}
 import org.apache.spark.sql.execution.SQLExecution
 import org.apache.spark.sql.execution.datasources.v2.{StreamingDataSourceV2Relation, WriteToDataSourceV2}
-import org.apache.spark.sql.execution.streaming.sources.{InternalRowMicroBatchWriter, MicroBatchWriter}
+import org.apache.spark.sql.execution.streaming.sources.MicroBatchWriter
 import org.apache.spark.sql.sources.v2.{DataSourceOptions, DataSourceV2, MicroBatchReadSupport, StreamWriteSupport}
 import org.apache.spark.sql.sources.v2.reader.streaming.{MicroBatchReader, Offset => OffsetV2}
-import org.apache.spark.sql.sources.v2.writer.SupportsWriteInternalRow
 import org.apache.spark.sql.streaming.{OutputMode, ProcessingTime, Trigger}
 import org.apache.spark.util.{Clock, Utils}

@@ -477,12 +476,7 @@ class MicroBatchExecution(
           newAttributePlan.schema,
           outputMode,
           new DataSourceOptions(extraOptions.asJava))
-        if (writer.isInstanceOf[SupportsWriteInternalRow]) {
-          WriteToDataSourceV2(
-            new InternalRowMicroBatchWriter(currentBatchId, writer), newAttributePlan)
-        } else {
-          WriteToDataSourceV2(new MicroBatchWriter(currentBatchId, writer), newAttributePlan)
-        }
+        WriteToDataSourceV2(new MicroBatchWriter(currentBatchId, writer), newAttributePlan)
       case _ => throw new IllegalArgumentException(s"unknown sink type for $sink")
     }

sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousWriteRDD.scala

Lines changed: 0 additions & 1 deletion
@@ -44,7 +44,6 @@ class ContinuousWriteRDD(var prev: RDD[InternalRow], writeTask: DataWriterFactor
       SparkEnv.get)
     EpochTracker.initializeCurrentEpoch(
       context.getLocalProperty(ContinuousExecution.START_EPOCH_KEY).toLong)
-
     while (!context.isInterrupted() && !context.isCompleted()) {
       var dataWriter: DataWriter[InternalRow] = null
       // write the data and commit this writer.

sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/WriteToContinuousDataSourceExec.scala

Lines changed: 7 additions & 62 deletions
@@ -19,18 +19,14 @@ package org.apache.spark.sql.execution.streaming.continuous

 import scala.util.control.NonFatal

-import org.apache.spark.{SparkEnv, SparkException, TaskContext}
+import org.apache.spark.SparkException
 import org.apache.spark.internal.Logging
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions.Attribute
 import org.apache.spark.sql.execution.SparkPlan
-import org.apache.spark.sql.execution.datasources.v2.{DataWritingSparkTask, InternalRowDataWriterFactory}
-import org.apache.spark.sql.execution.datasources.v2.DataWritingSparkTask.{logError, logInfo}
 import org.apache.spark.sql.execution.streaming.StreamExecution
-import org.apache.spark.sql.sources.v2.writer._
 import org.apache.spark.sql.sources.v2.writer.streaming.StreamWriter
-import org.apache.spark.util.Utils

 /**
  * The physical plan for writing data into a continuous processing [[StreamWriter]].
@@ -41,29 +37,20 @@ case class WriteToContinuousDataSourceExec(writer: StreamWriter, query: SparkPla
   override def output: Seq[Attribute] = Nil

   override protected def doExecute(): RDD[InternalRow] = {
-    val writerFactory = writer match {
-      case w: SupportsWriteInternalRow => w.createInternalRowWriterFactory()
-      case _ => new InternalRowDataWriterFactory(writer.createWriterFactory(), query.schema)
-    }
-
-    val rdd = query.execute()
+    val writerFactory = writer.createWriterFactory()
+    val rdd = new ContinuousWriteRDD(query.execute(), writerFactory)

     logInfo(s"Start processing data source writer: $writer. " +
-      s"The input RDD has ${rdd.getNumPartitions} partitions.")
-    // Let the epoch coordinator know how many partitions the write RDD has.
+      s"The input RDD has ${rdd.partitions.length} partitions.")
     EpochCoordinatorRef.get(
-        sparkContext.getLocalProperty(ContinuousExecution.EPOCH_COORDINATOR_ID_KEY),
-        sparkContext.env)
+      sparkContext.getLocalProperty(ContinuousExecution.EPOCH_COORDINATOR_ID_KEY),
+      sparkContext.env)
       .askSync[Unit](SetWriterPartitions(rdd.getNumPartitions))

     try {
       // Force the RDD to run so continuous processing starts; no data is actually being collected
       // to the driver, as ContinuousWriteRDD outputs nothing.
-      sparkContext.runJob(
-        rdd,
-        (context: TaskContext, iter: Iterator[InternalRow]) =>
-          WriteToContinuousDataSourceExec.run(writerFactory, context, iter),
-        rdd.partitions.indices)
+      rdd.collect()
     } catch {
       case _: InterruptedException =>
         // Interruption is how continuous queries are ended, so accept and ignore the exception.
@@ -80,45 +67,3 @@ case class WriteToContinuousDataSourceExec(writer: StreamWriter, query: SparkPla
     sparkContext.emptyRDD
   }
 }
-
-object WriteToContinuousDataSourceExec extends Logging {
-  def run(
-      writeTask: DataWriterFactory[InternalRow],
-      context: TaskContext,
-      iter: Iterator[InternalRow]): Unit = {
-    val epochCoordinator = EpochCoordinatorRef.get(
-      context.getLocalProperty(ContinuousExecution.EPOCH_COORDINATOR_ID_KEY),
-      SparkEnv.get)
-    var currentEpoch = context.getLocalProperty(ContinuousExecution.START_EPOCH_KEY).toLong
-
-    do {
-      var dataWriter: DataWriter[InternalRow] = null
-      // write the data and commit this writer.
-      Utils.tryWithSafeFinallyAndFailureCallbacks(block = {
-        try {
-          dataWriter = writeTask.createDataWriter(
-            context.partitionId(), context.attemptNumber(), currentEpoch)
-          while (iter.hasNext) {
-            dataWriter.write(iter.next())
-          }
-          logInfo(s"Writer for partition ${context.partitionId()} is committing.")
-          val msg = dataWriter.commit()
-          logInfo(s"Writer for partition ${context.partitionId()} committed.")
-          epochCoordinator.send(
-            CommitPartitionEpoch(context.partitionId(), currentEpoch, msg)
-          )
-          currentEpoch += 1
-        } catch {
-          case _: InterruptedException =>
-            // Continuous shutdown always involves an interrupt. Just finish the task.
-        }
-      })(catchBlock = {
-        // If there is an error, abort this writer. We enter this callback in the middle of
-        // rethrowing an exception, so runContinuous will stop executing at this point.
-        logError(s"Writer for partition ${context.partitionId()} is aborting.")
-        if (dataWriter != null) dataWriter.abort()
-        logError(s"Writer for partition ${context.partitionId()} aborted.")
-      })
-    } while (!context.isInterrupted())
-  }
-}

sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ConsoleWriter.scala

Lines changed: 5 additions & 6 deletions
@@ -17,10 +17,10 @@

 package org.apache.spark.sql.execution.streaming.sources

-import scala.collection.JavaConverters._
-
 import org.apache.spark.internal.Logging
-import org.apache.spark.sql.{Row, SparkSession}
+import org.apache.spark.sql.{Dataset, SparkSession}
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.plans.logical.LocalRelation
 import org.apache.spark.sql.sources.v2.DataSourceOptions
 import org.apache.spark.sql.sources.v2.writer.{DataWriterFactory, WriterCommitMessage}
 import org.apache.spark.sql.sources.v2.writer.streaming.StreamWriter
@@ -39,7 +39,7 @@ class ConsoleWriter(schema: StructType, options: DataSourceOptions)
   assert(SparkSession.getActiveSession.isDefined)
   protected val spark = SparkSession.getActiveSession.get

-  def createWriterFactory(): DataWriterFactory[Row] = PackedRowWriterFactory
+  def createWriterFactory(): DataWriterFactory[InternalRow] = PackedRowWriterFactory

   override def commit(epochId: Long, messages: Array[WriterCommitMessage]): Unit = {
     // We have to print a "Batch" label for the epoch for compatibility with the pre-data source V2
@@ -62,8 +62,7 @@ class ConsoleWriter(schema: StructType, options: DataSourceOptions)
     println(printMessage)
     println("-------------------------------------------")
     // scalastyle:off println
-    spark
-      .createDataFrame(rows.toList.asJava, schema)
+    Dataset.ofRows(spark, LocalRelation(schema.toAttributes, rows))
       .show(numRowsToShow, isTruncated)
   }
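
Since the buffered rows are now InternalRows, the console sink can no longer call createDataFrame (which expects external Rows); it wraps them in a LocalRelation and builds the DataFrame with Dataset.ofRows. A minimal sketch of that pattern; note that Dataset.ofRows and StructType.toAttributes are Spark-internal and only accessible from code that, like ConsoleWriter, lives inside Spark's own sql packages. The helper name is invented.

import org.apache.spark.sql.{DataFrame, Dataset, SparkSession}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.plans.logical.LocalRelation
import org.apache.spark.sql.types.StructType

// Hypothetical helper mirroring what ConsoleWriter does after this patch:
// build a DataFrame from already-collected InternalRows without converting
// them back to external Rows.
object ConsoleLikeOutput {
  def internalRowsToDataFrame(
      spark: SparkSession,
      schema: StructType,
      rows: Seq[InternalRow]): DataFrame = {
    Dataset.ofRows(spark, LocalRelation(schema.toAttributes, rows))
  }
}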
