
Commit 522c24e

Adds OutputWriterFactory
1 parent 047d40d commit 522c24e

4 files changed: 59 additions & 64 deletions


sql/core/src/main/scala/org/apache/spark/sql/parquet/newParquet.scala

Lines changed: 16 additions & 17 deletions
@@ -54,14 +54,10 @@ private[sql] class DefaultSource extends HadoopFsRelationProvider {
 }
 
 // NOTE: This class is instantiated and used on executor side only, no need to be serializable.
-private[sql] class ParquetOutputWriter extends OutputWriter {
-  private var recordWriter: RecordWriter[Void, Row] = _
-  private var taskAttemptContext: TaskAttemptContext = _
-
-  override def init(
-      path: String,
-      dataSchema: StructType,
-      context: TaskAttemptContext): Unit = {
+private[sql] class ParquetOutputWriter(path: String, context: TaskAttemptContext)
+  extends OutputWriter {
+
+  private val recordWriter: RecordWriter[Void, Row] = {
     val conf = context.getConfiguration
     val outputFormat = {
       // When appending new Parquet files to an existing Parquet file directory, to avoid
@@ -86,9 +82,8 @@ private[sql] class ParquetOutputWriter extends OutputWriter {
           case name if name.startsWith("_") => 0
           case name if name.startsWith(".") => 0
           case name => sys.error(
-            s"""Trying to write Parquet files to directory $outputPath,
-                |but found items with illegal name "$name"
-              """.stripMargin.replace('\n', ' ').trim)
+            s"Trying to write Parquet files to directory $outputPath, " +
+              s"but found items with illegal name '$name'.")
         }.reduceOption(_ max _).getOrElse(0)
       } else {
         0
@@ -111,13 +106,12 @@ private[sql] class ParquetOutputWriter extends OutputWriter {
       }
     }
 
-    recordWriter = outputFormat.getRecordWriter(context)
-    taskAttemptContext = context
+    outputFormat.getRecordWriter(context)
   }
 
   override def write(row: Row): Unit = recordWriter.write(null, row)
 
-  override def close(): Unit = recordWriter.close(taskAttemptContext)
+  override def close(): Unit = recordWriter.close(context)
 }
 
 private[sql] class ParquetRelation2(
@@ -175,8 +169,6 @@ private[sql] class ParquetRelation2(
     }
   }
 
-  override def outputWriterClass: Class[_ <: OutputWriter] = classOf[ParquetOutputWriter]
-
   override def dataSchema: StructType = metadataCache.dataSchema
 
   override private[sql] def refresh(): Unit = {
@@ -189,7 +181,7 @@ private[sql] class ParquetRelation2(
 
   override val sizeInBytes = metadataCache.dataStatuses.map(_.getLen).sum
 
-  override def prepareJobForWrite(job: Job): Unit = {
+  override def prepareJobForWrite(job: Job): OutputWriterFactory = {
     val conf = ContextUtil.getConfiguration(job)
 
     val committerClass =
@@ -224,6 +216,13 @@ private[sql] class ParquetRelation2(
         .getOrElse(
           sqlContext.conf.parquetCompressionCodec.toUpperCase,
           CompressionCodecName.UNCOMPRESSED).name())
+
+    new OutputWriterFactory {
+      override def newInstance(
+          path: String, dataSchema: StructType, context: TaskAttemptContext): OutputWriter = {
+        new ParquetOutputWriter(path, context)
+      }
+    }
   }
 
   override def buildScan(

sql/core/src/main/scala/org/apache/spark/sql/sources/commands.scala

Lines changed: 4 additions & 7 deletions
@@ -261,15 +261,15 @@ private[sql] abstract class BaseWriterContainer(
 
   protected val dataSchema = relation.dataSchema
 
-  protected val outputWriterClass: Class[_ <: OutputWriter] = relation.outputWriterClass
+  protected var outputWriterFactory: OutputWriterFactory = _
 
   private var outputFormatClass: Class[_ <: OutputFormat[_, _]] = _
 
   def driverSideSetup(): Unit = {
     setupIDs(0, 0, 0)
     setupConf()
     taskAttemptContext = newTaskAttemptContext(serializableConf.value, taskAttemptId)
-    relation.prepareJobForWrite(job)
+    outputWriterFactory = relation.prepareJobForWrite(job)
     outputFormatClass = job.getOutputFormatClass
     outputCommitter = newOutputCommitter(taskAttemptContext)
     outputCommitter.setupJob(jobContext)
@@ -353,9 +353,8 @@ private[sql] class DefaultWriterContainer(
   @transient private var writer: OutputWriter = _
 
   override protected def initWriters(): Unit = {
-    writer = outputWriterClass.newInstance()
     taskAttemptContext.getConfiguration.set("spark.sql.sources.output.path", outputPath)
-    writer.init(getWorkPath, dataSchema, taskAttemptContext)
+    writer = outputWriterFactory.newInstance(getWorkPath, dataSchema, taskAttemptContext)
   }
 
   override def outputWriterForRow(row: Row): OutputWriter = writer
@@ -398,12 +397,10 @@ private[sql] class DynamicPartitionWriterContainer(
 
     outputWriters.getOrElseUpdate(partitionPath, {
       val path = new Path(getWorkPath, partitionPath)
-      val writer = outputWriterClass.newInstance()
       taskAttemptContext.getConfiguration.set(
         "spark.sql.sources.output.path",
        new Path(outputPath, partitionPath).toString)
-      writer.init(path.toString, dataSchema, taskAttemptContext)
-      writer
+      outputWriterFactory.newInstance(path.toString, dataSchema, taskAttemptContext)
    })
   }
 

sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala

Lines changed: 25 additions & 24 deletions
@@ -280,33 +280,42 @@ trait CatalystScan {
 
 /**
  * ::Experimental::
- * [[OutputWriter]] is used together with [[HadoopFsRelation]] for persisting rows to the
- * underlying file system. Subclasses of [[OutputWriter]] must provide a zero-argument constructor.
- * An [[OutputWriter]] instance is created and initialized when a new output file is opened on
- * executor side. This instance is used to persist rows to this single output file.
+ * A factory that produces [[OutputWriter]]s. A new [[OutputWriterFactory]] is created on driver
+ * side for each write job issued when writing to a [[HadoopFsRelation]], and then gets serialized
+ * to executor side to create actual [[OutputWriter]]s on the fly.
  *
  * @since 1.4.0
  */
 @Experimental
-abstract class OutputWriter {
+trait OutputWriterFactory extends Serializable {
   /**
-   * Initializes this [[OutputWriter]] before any rows are persisted.
+   * When writing to a [[HadoopFsRelation]], this method gets called by each task on executor side
+   * to instantiate new [[OutputWriter]]s.
    *
    * @param path Path of the file to which this [[OutputWriter]] is supposed to write. Note that
    *        this may not point to the final output file. For example, `FileOutputFormat` writes to
    *        temporary directories and then merge written files back to the final destination. In
    *        this case, `path` points to a temporary output file under the temporary directory.
    * @param dataSchema Schema of the rows to be written. Partition columns are not included in the
-   *        schema if the corresponding relation is partitioned.
+   *        schema if the relation being written is partitioned.
    * @param context The Hadoop MapReduce task context.
    *
    * @since 1.4.0
    */
-  def init(
-      path: String,
-      dataSchema: StructType,
-      context: TaskAttemptContext): Unit = ()
+  def newInstance(path: String, dataSchema: StructType, context: TaskAttemptContext): OutputWriter
+}
 
+/**
+ * ::Experimental::
+ * [[OutputWriter]] is used together with [[HadoopFsRelation]] for persisting rows to the
+ * underlying file system. Subclasses of [[OutputWriter]] must provide a zero-argument constructor.
+ * An [[OutputWriter]] instance is created and initialized when a new output file is opened on
+ * executor side. This instance is used to persist rows to this single output file.
+ *
+ * @since 1.4.0
+ */
+@Experimental
+abstract class OutputWriter {
   /**
    * Persists a single row. Invoked on the executor side. When writing to dynamically partitioned
    * tables, dynamic partition columns are not included in rows to be written.
@@ -333,8 +342,8 @@ abstract class OutputWriter {
  * filter using selected predicates before producing an RDD containing all matching tuples as
  * [[Row]] objects. In addition, when reading from Hive style partitioned tables stored in file
  * systems, it's able to discover partitioning information from the paths of input directories, and
- * perform partition pruning before start reading the data. Subclasses of [[HadoopFsRelation()]] must
- * override one of the three `buildScan` methods to implement the read path.
+ * perform partition pruning before start reading the data. Subclasses of [[HadoopFsRelation()]]
+ * must override one of the three `buildScan` methods to implement the read path.
  *
  * For the write path, it provides the ability to write to both non-partitioned and partitioned
  * tables. Directory layout of the partitioned tables is compatible with Hive.
@@ -520,22 +529,14 @@ abstract class HadoopFsRelation private[sql](
   }
 
   /**
-   * Client side preparation for data writing can be put here. For example, user defined output
-   * committer can be configured here.
+   * Prepares a write job and returns an [[OutputWriterFactory]]. Client side job preparation can
+   * be put here. For example, user defined output committer can be configured here.
    *
    * Note that the only side effect expected here is mutating `job` via its setters. Especially,
    * Spark SQL caches [[BaseRelation]] instances for performance, mutating relation internal states
    * may cause unexpected behaviors.
    *
    * @since 1.4.0
    */
-  def prepareJobForWrite(job: Job): Unit = ()
-
-  /**
-   * This method is responsible for producing a new [[OutputWriter]] for each newly opened output
-   * file on the executor side.
-   *
-   * @since 1.4.0
-   */
-  def outputWriterClass: Class[_ <: OutputWriter]
+  def prepareJobForWrite(job: Job): OutputWriterFactory
 }
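Because OutputWriterFactory extends Serializable and is built on the driver but invoked on executors, an implementation of prepareJobForWrite should capture only small, serializable values inside the factory it returns. A minimal sketch under that assumption; MyOutputWriter and the "my.compression.codec" key are hypothetical, while the prepareJobForWrite and newInstance signatures are the ones defined above:

override def prepareJobForWrite(job: Job): OutputWriterFactory = {
  // Driver side: read whatever configuration the writers will need up front.
  // The Hadoop Job should only be mutated here, never captured by the factory.
  val codec = job.getConfiguration.get("my.compression.codec", "none")

  new OutputWriterFactory {
    // Executor side: called by each write task, once per output file.
    override def newInstance(
        path: String,
        dataSchema: StructType,
        context: TaskAttemptContext): OutputWriter = {
      new MyOutputWriter(path, dataSchema, codec, context)
    }
  }
}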

sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextRelation.scala

Lines changed: 14 additions & 16 deletions
@@ -24,7 +24,7 @@ import com.google.common.base.Objects
 import org.apache.hadoop.fs.Path
 import org.apache.hadoop.io.{NullWritable, Text}
 import org.apache.hadoop.mapreduce.lib.output.{FileOutputFormat, TextOutputFormat}
-import org.apache.hadoop.mapreduce.{RecordWriter, TaskAttemptContext}
+import org.apache.hadoop.mapreduce.{Job, RecordWriter, TaskAttemptContext}
 
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.catalyst.expressions.{Cast, Literal}
@@ -59,24 +59,16 @@ class AppendingTextOutputFormat(outputFile: Path) extends TextOutputFormat[NullW
   }
 }
 
-class SimpleTextOutputWriter extends OutputWriter {
-  private var recordWriter: RecordWriter[NullWritable, Text] = _
-  private var taskAttemptContext: TaskAttemptContext = _
-
-  override def init(
-      path: String,
-      dataSchema: StructType,
-      context: TaskAttemptContext): Unit = {
-    recordWriter = new AppendingTextOutputFormat(new Path(path)).getRecordWriter(context)
-    taskAttemptContext = context
-  }
+class SimpleTextOutputWriter(path: String, context: TaskAttemptContext) extends OutputWriter {
+  private val recordWriter: RecordWriter[NullWritable, Text] =
+    new AppendingTextOutputFormat(new Path(path)).getRecordWriter(context)
 
   override def write(row: Row): Unit = {
     val serialized = row.toSeq.map(_.toString).mkString(",")
     recordWriter.write(null, new Text(serialized))
   }
 
-  override def close(): Unit = recordWriter.close(taskAttemptContext)
+  override def close(): Unit = recordWriter.close(context)
 }
 
 /**
@@ -110,9 +102,6 @@ class SimpleTextRelation(
   override def hashCode(): Int =
     Objects.hashCode(paths, maybeDataSchema, dataSchema)
 
-  override def outputWriterClass: Class[_ <: OutputWriter] =
-    classOf[SimpleTextOutputWriter]
-
   override def buildScan(inputPaths: Array[String]): RDD[Row] = {
     val fields = dataSchema.map(_.dataType)
 
@@ -122,4 +111,13 @@ class SimpleTextRelation(
       }: _*)
     }
   }
+
+  override def prepareJobForWrite(job: Job): OutputWriterFactory = new OutputWriterFactory {
+    override def newInstance(
+        path: String,
+        dataSchema: StructType,
+        context: TaskAttemptContext): OutputWriter = {
+      new SimpleTextOutputWriter(path, context)
+    }
+  }
 }
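Taken together with the SimpleTextRelation change above, the migration for an external HadoopFsRelation is mechanical; a before/after sketch with a hypothetical MyWriter class:

// Before this commit: the writer needed a zero-argument constructor and an init() hook.
override def outputWriterClass: Class[_ <: OutputWriter] = classOf[MyWriter]

// After this commit: the relation returns a serializable factory, and the writer can
// take its output path and task context directly through its constructor.
override def prepareJobForWrite(job: Job): OutputWriterFactory = new OutputWriterFactory {
  override def newInstance(
      path: String, dataSchema: StructType, context: TaskAttemptContext): OutputWriter =
    new MyWriter(path, context)
}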
