Commit 2cffd18

cloud-fan authored and rdblue committed
[SPARK-24991][SQL] use InternalRow in DataSourceWriter
A follow-up of apache#21118. Since we use `InternalRow` in the read API of data source v2, we should do the same thing for the write API. Tested with existing tests. Author: Wenchen Fan <wenchen@databricks.com> Closes apache#21948 from cloud-fan/row-write.
1 parent cd7f5a7 commit 2cffd18
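
For orientation, the shape of a v2 writer after this change is sketched below. This is a minimal, hypothetical example (ExampleWriter, ExampleWriterFactory and ExampleDataWriter are made-up names, not part of the patch): createWriterFactory() now returns a DataWriterFactory[InternalRow], so Spark hands rows in the internal format to the writer directly instead of going through a Row-based factory.

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.sources.v2.writer._

// Hypothetical example, not part of this commit: a DataSourceWriter whose
// factory is parameterized with InternalRow, as the updated API requires.
class ExampleWriter extends DataSourceWriter {
  override def createWriterFactory(): DataWriterFactory[InternalRow] = new ExampleWriterFactory

  override def commit(messages: Array[WriterCommitMessage]): Unit = {}
  override def abort(messages: Array[WriterCommitMessage]): Unit = {}
}

class ExampleWriterFactory extends DataWriterFactory[InternalRow] {
  // Factory signature as of this commit: (partitionId, attemptNumber).
  override def createDataWriter(partitionId: Int, attemptNumber: Int): DataWriter[InternalRow] =
    new ExampleDataWriter
}

// A do-nothing DataWriter, just so the sketch is self-contained.
class ExampleDataWriter extends DataWriter[InternalRow] {
  override def write(record: InternalRow): Unit = ()
  override def commit(): WriterCommitMessage = null
  override def abort(): Unit = ()
}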

File tree

7 files changed (+15, -141 lines)

sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataSourceWriter.java

Lines changed: 2 additions & 2 deletions
@@ -18,8 +18,8 @@
 package org.apache.spark.sql.sources.v2.writer;
 
 import org.apache.spark.annotation.InterfaceStability;
-import org.apache.spark.sql.Row;
 import org.apache.spark.sql.SaveMode;
+import org.apache.spark.sql.catalyst.InternalRow;
 import org.apache.spark.sql.sources.v2.DataSourceOptions;
 import org.apache.spark.sql.sources.v2.WriteSupport;
 import org.apache.spark.sql.types.StructType;
@@ -57,7 +57,7 @@ public interface DataSourceWriter {
    * If this method fails (by throwing an exception), the action will fail and no Spark job will be
    * submitted.
    */
-  DataWriterFactory<Row> createWriterFactory();
+  DataWriterFactory<InternalRow> createWriterFactory();
 
   /**
    * Returns whether Spark should use the commit coordinator to ensure that only one attempt for

sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataWriter.java

Lines changed: 1 addition & 3 deletions
@@ -49,9 +49,7 @@
  * successfully, and have a way to revert committed data writers without the commit message, because
  * Spark only accepts the commit message that arrives first and ignore others.
  *
- * Note that, Currently the type `T` can only be {@link org.apache.spark.sql.Row} for normal data
- * source writers, or {@link org.apache.spark.sql.catalyst.InternalRow} for data source writers
- * that mix in {@link SupportsWriteInternalRow}.
+ * Note that, Currently the type `T` can only be {@link org.apache.spark.sql.catalyst.InternalRow}.
  */
 @InterfaceStability.Evolving
 public interface DataWriter<T> {
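
With T now fixed to InternalRow, a writer reads columns through the positional, typed getters of the internal format; string columns in particular arrive as UTF8String rather than java.lang.String. A minimal sketch (LoggingDataWriter is a made-up name and the two-column schema is assumed):

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.sources.v2.writer.{DataWriter, WriterCommitMessage}
import org.apache.spark.unsafe.types.UTF8String

// Hypothetical writer for an assumed schema (id: Long, name: String).
class LoggingDataWriter extends DataWriter[InternalRow] {
  override def write(record: InternalRow): Unit = {
    val id: Long = record.getLong(0)               // typed, positional access
    val name: UTF8String = record.getUTF8String(1) // strings come back as UTF8String
    println(s"$id,$name")
  }

  override def commit(): WriterCommitMessage = null
  override def abort(): Unit = ()
}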

sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataWriterFactory.java

Lines changed: 4 additions & 1 deletion
@@ -33,7 +33,10 @@
 public interface DataWriterFactory<T> extends Serializable {
 
   /**
-   * Returns a data writer to do the actual writing work.
+   * Returns a data writer to do the actual writing work. Note that, Spark will reuse the same data
+   * object instance when sending data to the data writer, for better performance. Data writers
+   * are responsible for defensive copies if necessary, e.g. copy the data before buffer it in a
+   * list.
    *
    * If this method fails (by throwing an exception), the action will fail and no Spark job will be
    * submitted.
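
The added javadoc is the important behavioural contract here: Spark may reuse the same InternalRow instance across write() calls, so a writer that holds on to records must copy them first. A minimal sketch of such a defensive copy (BufferingDataWriter is a made-up name, not from this commit):

import scala.collection.mutable.ArrayBuffer

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.sources.v2.writer.{DataWriter, WriterCommitMessage}

// Hypothetical buffering writer: without the copy(), every buffered entry
// could end up pointing at the single, reused row object.
class BufferingDataWriter extends DataWriter[InternalRow] {
  private val buffer = ArrayBuffer.empty[InternalRow]

  override def write(record: InternalRow): Unit = {
    buffer += record.copy() // defensive copy before buffering, per the new javadoc
  }

  override def commit(): WriterCommitMessage = {
    // flush the buffered rows to the sink here, then report success
    null
  }

  override def abort(): Unit = buffer.clear()
}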

sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/SupportsWriteInternalRow.java

Lines changed: 0 additions & 41 deletions
This file was deleted.

sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2.scala

Lines changed: 1 addition & 25 deletions
@@ -47,11 +47,7 @@ case class WriteToDataSourceV2Exec(writer: DataSourceWriter, query: SparkPlan) e
   override def output: Seq[Attribute] = Nil
 
   override protected def doExecute(): RDD[InternalRow] = {
-    val writeTask = writer match {
-      case w: SupportsWriteInternalRow => w.createInternalRowWriterFactory()
-      case _ => new InternalRowDataWriterFactory(writer.createWriterFactory(), query.schema)
-    }
-
+    val writeTask = writer.createWriterFactory()
     val useCommitCoordinator = writer.useCommitCoordinator
     val rdd = query.execute()
     val messages = new Array[WriterCommitMessage](rdd.partitions.length)
@@ -142,23 +138,3 @@ object DataWritingSparkTask extends Logging {
   }
 }
 
-class InternalRowDataWriterFactory(
-    rowWriterFactory: DataWriterFactory[Row],
-    schema: StructType) extends DataWriterFactory[InternalRow] {
-
-  override def createDataWriter(partitionId: Int, attemptNumber: Int): DataWriter[InternalRow] = {
-    new InternalRowDataWriter(
-      rowWriterFactory.createDataWriter(partitionId, attemptNumber),
-      RowEncoder.apply(schema).resolveAndBind())
-  }
-}
-
-class InternalRowDataWriter(rowWriter: DataWriter[Row], encoder: ExpressionEncoder[Row])
-  extends DataWriter[InternalRow] {
-
-  override def write(record: InternalRow): Unit = rowWriter.write(encoder.fromRow(record))
-
-  override def commit(): WriterCommitMessage = rowWriter.commit()
-
-  override def abort(): Unit = rowWriter.abort()
-}

sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala

Lines changed: 0 additions & 7 deletions
@@ -188,13 +188,6 @@ class DataSourceV2Suite extends QueryTest with SharedSQLContext {
       assert(e2.getMessage.contains("Writing job aborted"))
       // make sure we don't have partial data.
       assert(spark.read.format(cls.getName).option("path", path).load().collect().isEmpty)
-
-      // test internal row writer
-      spark.range(5).select('id as 'i, -'id as 'j).write.format(cls.getName)
-        .option("path", path).option("internal", "true").mode("overwrite").save()
-      checkAnswer(
-        spark.read.format(cls.getName).option("path", path).load(),
-        spark.range(5).select('id, -'id))
     }
   }
 }

sql/core/src/test/scala/org/apache/spark/sql/sources/v2/SimpleWritableDataSource.scala

Lines changed: 7 additions & 62 deletions
@@ -26,7 +26,7 @@ import org.apache.hadoop.conf.Configuration
 import org.apache.hadoop.fs.{FileSystem, FSDataInputStream, Path}
 
 import org.apache.spark.SparkContext
-import org.apache.spark.sql.{Row, SaveMode}
+import org.apache.spark.sql.SaveMode
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.sources.v2.reader.{DataSourceReader, InputPartition, InputPartitionReader}
 import org.apache.spark.sql.sources.v2.writer._
@@ -65,9 +65,9 @@ class SimpleWritableDataSource extends DataSourceV2 with ReadSupport with WriteS
   }
 
   class Writer(jobId: String, path: String, conf: Configuration) extends DataSourceWriter {
-    override def createWriterFactory(): DataWriterFactory[Row] = {
+    override def createWriterFactory(): DataWriterFactory[InternalRow] = {
       SimpleCounter.resetCounter
-      new SimpleCSVDataWriterFactory(path, jobId, new SerializableConfiguration(conf))
+      new CSVDataWriterFactory(path, jobId, new SerializableConfiguration(conf))
     }
 
     override def onDataWriterCommit(message: WriterCommitMessage): Unit = {
@@ -97,18 +97,6 @@ class SimpleWritableDataSource extends DataSourceV2 with ReadSupport with WriteS
     }
   }
 
-  class InternalRowWriter(jobId: String, path: String, conf: Configuration)
-    extends Writer(jobId, path, conf) with SupportsWriteInternalRow {
-
-    override def createWriterFactory(): DataWriterFactory[Row] = {
-      throw new IllegalArgumentException("not expected!")
-    }
-
-    override def createInternalRowWriterFactory(): DataWriterFactory[InternalRow] = {
-      new InternalRowCSVDataWriterFactory(path, jobId, new SerializableConfiguration(conf))
-    }
-  }
-
   override def createReader(options: DataSourceOptions): DataSourceReader = {
     val path = new Path(options.get("path").get())
     val conf = SparkContext.getActive.get.hadoopConfiguration
@@ -124,7 +112,6 @@ class SimpleWritableDataSource extends DataSourceV2 with ReadSupport with WriteS
     assert(!SparkContext.getActive.get.conf.getBoolean("spark.speculation", false))
 
     val path = new Path(options.get("path").get())
-    val internal = options.get("internal").isPresent
     val conf = SparkContext.getActive.get.hadoopConfiguration
     val fs = path.getFileSystem(conf)
 
@@ -142,17 +129,8 @@ class SimpleWritableDataSource extends DataSourceV2 with ReadSupport with WriteS
       fs.delete(path, true)
     }
 
-    Optional.of(createWriter(jobId, path, conf, internal))
-  }
-
-  private def createWriter(
-      jobId: String, path: Path, conf: Configuration, internal: Boolean): DataSourceWriter = {
     val pathStr = path.toUri.toString
-    if (internal) {
-      new InternalRowWriter(jobId, pathStr, conf)
-    } else {
-      new Writer(jobId, pathStr, conf)
-    }
+    Optional.of(new Writer(jobId, pathStr, conf))
   }
 }
 
@@ -204,51 +182,18 @@ private[v2] object SimpleCounter {
   }
 }
 
-class SimpleCSVDataWriterFactory(path: String, jobId: String, conf: SerializableConfiguration)
-  extends DataWriterFactory[Row] {
-
-  override def createDataWriter(partitionId: Int, attemptNumber: Int): DataWriter[Row] = {
-    val jobPath = new Path(new Path(path, "_temporary"), jobId)
-    val filePath = new Path(jobPath, s"$jobId-$partitionId-$attemptNumber")
-    val fs = filePath.getFileSystem(conf.value)
-    new SimpleCSVDataWriter(fs, filePath)
-  }
-}
-
-class SimpleCSVDataWriter(fs: FileSystem, file: Path) extends DataWriter[Row] {
-
-  private val out = fs.create(file)
-
-  override def write(record: Row): Unit = {
-    out.writeBytes(s"${record.getLong(0)},${record.getLong(1)}\n")
-  }
-
-  override def commit(): WriterCommitMessage = {
-    out.close()
-    null
-  }
-
-  override def abort(): Unit = {
-    try {
-      out.close()
-    } finally {
-      fs.delete(file, false)
-    }
-  }
-}
-
-class InternalRowCSVDataWriterFactory(path: String, jobId: String, conf: SerializableConfiguration)
+class CSVDataWriterFactory(path: String, jobId: String, conf: SerializableConfiguration)
   extends DataWriterFactory[InternalRow] {
 
   override def createDataWriter(partitionId: Int, attemptNumber: Int): DataWriter[InternalRow] = {
     val jobPath = new Path(new Path(path, "_temporary"), jobId)
     val filePath = new Path(jobPath, s"$jobId-$partitionId-$attemptNumber")
     val fs = filePath.getFileSystem(conf.value)
-    new InternalRowCSVDataWriter(fs, filePath)
+    new CSVDataWriter(fs, filePath)
  }
 }
 
-class InternalRowCSVDataWriter(fs: FileSystem, file: Path) extends DataWriter[InternalRow] {
+class CSVDataWriter(fs: FileSystem, file: Path) extends DataWriter[InternalRow] {
 
   private val out = fs.create(file)
 