@@ -21,28 +21,32 @@ import java.io.IOException
 import java.text.NumberFormat
 import java.util.Date
 
+import scala.collection.mutable
+
 import org.apache.hadoop.fs.Path
+import org.apache.hadoop.hive.conf.HiveConf.ConfVars
 import org.apache.hadoop.hive.ql.exec.{FileSinkOperator, Utilities}
 import org.apache.hadoop.hive.ql.io.{HiveFileFormatUtils, HiveOutputFormat}
 import org.apache.hadoop.hive.ql.plan.FileSinkDesc
-import org.apache.hadoop.mapred._
 import org.apache.hadoop.io.Writable
+import org.apache.hadoop.mapred._
 
+import org.apache.spark.sql.Row
 import org.apache.spark.{Logging, SerializableWritable, SparkHadoopWriter}
 
 /**
  * Internal helper class that saves an RDD using a Hive OutputFormat.
  * It is based on [[SparkHadoopWriter]].
  */
-private[hive] class SparkHiveHadoopWriter(
+private[hive] class SparkHiveWriterContainer(
     @transient jobConf: JobConf,
     fileSinkConf: FileSinkDesc)
   extends Logging
   with SparkHadoopMapRedUtil
   with Serializable {
 
   private val now = new Date()
-  private val conf = new SerializableWritable(jobConf)
+  protected val conf = new SerializableWritable(jobConf)
 
   private var jobID = 0
   private var splitID = 0
@@ -51,152 +55,75 @@ private[hive] class SparkHiveHadoopWriter(
   private var taID: SerializableWritable[TaskAttemptID] = null
 
   @transient private var writer: FileSinkOperator.RecordWriter = null
-  @transient private var format: HiveOutputFormat[AnyRef, Writable] = null
-  @transient private var committer: OutputCommitter = null
-  @transient private var jobContext: JobContext = null
-  @transient private var taskContext: TaskAttemptContext = null
+  @transient private lazy val committer = conf.value.getOutputCommitter
+  @transient private lazy val jobContext = newJobContext(conf.value, jID.value)
+  @transient private lazy val taskContext = newTaskAttemptContext(conf.value, taID.value)
+  @transient private lazy val outputFormat =
+    conf.value.getOutputFormat.asInstanceOf[HiveOutputFormat[AnyRef, Writable]]
 
-  def preSetup() {
+  def driverSideSetup() {
     setIDs(0, 0, 0)
     setConfParams()
-
-    val jCtxt = getJobContext()
-    getOutputCommitter().setupJob(jCtxt)
+    committer.setupJob(jobContext)
   }
 
-
-  def setup(jobid: Int, splitid: Int, attemptid: Int) {
-    setIDs(jobid, splitid, attemptid)
+  def executorSideSetup(jobId: Int, splitId: Int, attemptId: Int) {
+    setIDs(jobId, splitId, attemptId)
     setConfParams()
-  }
-
-  def open() {
-    val numfmt = NumberFormat.getInstance()
-    numfmt.setMinimumIntegerDigits(5)
-    numfmt.setGroupingUsed(false)
-
-    val extension = Utilities.getFileExtension(
-      conf.value,
-      fileSinkConf.getCompressed,
-      getOutputFormat())
-
-    val outputName = "part-" + numfmt.format(splitID) + extension
-    val path = FileOutputFormat.getTaskOutputPath(conf.value, outputName)
-
-    getOutputCommitter().setupTask(getTaskContext())
-    writer = HiveFileFormatUtils.getHiveRecordWriter(
-      conf.value,
-      fileSinkConf.getTableInfo,
-      conf.value.getOutputValueClass.asInstanceOf[Class[Writable]],
-      fileSinkConf,
-      path,
-      null)
+    committer.setupTask(taskContext)
   }
 
   /**
-   * create an HiveRecordWriter. imitate the above function open()
-   * @param dynamicPartPath the relative path for dynamic partition
-   *
-   * since this function is used to create different writer for
-   * different dynamic partition.So we need a parameter dynamicPartPath
-   * and use it we can calculate a new path and pass the new path to
-   * the function HiveFileFormatUtils.getHiveRecordWriter
+   * Create a `HiveRecordWriter`. A relative dynamic partition path can be used to create a writer
+   * for writing data to a dynamic partition.
    */
-  def open(dynamicPartPath: String) {
-    val numfmt = NumberFormat.getInstance()
-    numfmt.setMinimumIntegerDigits(5)
-    numfmt.setGroupingUsed(false)
-
-    val extension = Utilities.getFileExtension(
-      conf.value,
-      fileSinkConf.getCompressed,
-      getOutputFormat())
-
-    val outputName = "part-" + numfmt.format(splitID) + extension
-    val outputPath: Path = FileOutputFormat.getOutputPath(conf.value)
-    if (outputPath == null) {
-      throw new IOException("Undefined job output-path")
-    }
-    val workPath = new Path(outputPath, dynamicPartPath.stripPrefix("/")) // remove "/"
-    val path = new Path(workPath, outputName)
-    getOutputCommitter().setupTask(getTaskContext())
+  def open() {
     writer = HiveFileFormatUtils.getHiveRecordWriter(
       conf.value,
       fileSinkConf.getTableInfo,
       conf.value.getOutputValueClass.asInstanceOf[Class[Writable]],
       fileSinkConf,
-      path,
+      FileOutputFormat.getTaskOutputPath(conf.value, getOutputName),
       Reporter.NULL)
   }
 
-  def write(value: Writable) {
-    if (writer != null) {
-      writer.write(value)
-    } else {
-      throw new IOException("Writer is null, open() has not been called")
-    }
+  protected def getOutputName: String = {
+    val numberFormat = NumberFormat.getInstance()
+    numberFormat.setMinimumIntegerDigits(5)
+    numberFormat.setGroupingUsed(false)
+    val extension = Utilities.getFileExtension(conf.value, fileSinkConf.getCompressed, outputFormat)
+    "part-" + numberFormat.format(splitID) + extension
   }
 
+  def getLocalFileWriter(row: Row): FileSinkOperator.RecordWriter = writer
+
   def close() {
     // Seems the boolean value passed into close does not matter.
     writer.close(false)
   }
 
   def commit() {
-    val taCtxt = getTaskContext()
-    val cmtr = getOutputCommitter()
-    if (cmtr.needsTaskCommit(taCtxt)) {
+    if (committer.needsTaskCommit(taskContext)) {
       try {
-        cmtr.commitTask(taCtxt)
+        committer.commitTask(taskContext)
         logInfo (taID + ": Committed")
       } catch {
         case e: IOException =>
           logError("Error committing the output of task: " + taID.value, e)
-          cmtr.abortTask(taCtxt)
+          committer.abortTask(taskContext)
           throw e
       }
     } else {
-      logWarning ("No need to commit output of task: " + taID.value)
+      logInfo("No need to commit output of task: " + taID.value)
     }
   }
 
   def commitJob() {
-    // always ? Or if cmtr.needsTaskCommit ?
-    val cmtr = getOutputCommitter()
-    cmtr.commitJob(getJobContext())
+    committer.commitJob(jobContext)
   }
 
   // ********* Private Functions *********
 
-  private def getOutputFormat(): HiveOutputFormat[AnyRef,Writable] = {
-    if (format == null) {
-      format = conf.value.getOutputFormat()
-        .asInstanceOf[HiveOutputFormat[AnyRef,Writable]]
-    }
-    format
-  }
-
-  private def getOutputCommitter(): OutputCommitter = {
-    if (committer == null) {
-      committer = conf.value.getOutputCommitter
-    }
-    committer
-  }
-
-  private def getJobContext(): JobContext = {
-    if (jobContext == null) {
-      jobContext = newJobContext(conf.value, jID.value)
-    }
-    jobContext
-  }
-
-  private def getTaskContext(): TaskAttemptContext = {
-    if (taskContext == null) {
-      taskContext = newTaskAttemptContext(conf.value, taID.value)
-    }
-    taskContext
-  }
-
   private def setIDs(jobId: Int, splitId: Int, attemptId: Int) {
     jobID = jobId
     splitID = splitId
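
As an aside on the `getOutputName` helper added above: with a minimum of five integer digits and grouping disabled, the `NumberFormat` produces the usual zero-padded `part-NNNNN` file names. A standalone sketch of just the naming logic follows; the `.snappy` extension is an assumed stand-in for whatever `Utilities.getFileExtension` actually returns for the job's configuration.

```scala
import java.text.NumberFormat

// Illustrates the "part-" naming scheme used by getOutputName, without Hadoop/Hive.
val numberFormat = NumberFormat.getInstance()
numberFormat.setMinimumIntegerDigits(5)
numberFormat.setGroupingUsed(false)          // avoid "10,000"-style separators

val splitID = 3
val extension = ".snappy"                    // assumed; really comes from Utilities.getFileExtension
println("part-" + numberFormat.format(splitID) + extension)  // prints part-00003.snappy
```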
@@ -216,7 +143,7 @@ private[hive] class SparkHiveHadoopWriter(
   }
 }
 
-private[hive] object SparkHiveHadoopWriter {
+private[hive] object SparkHiveWriterContainer {
   def createPathFromString(path: String, conf: JobConf): Path = {
     if (path == null) {
       throw new IllegalArgumentException("Output path is null")
@@ -226,6 +153,61 @@ private[hive] object SparkHiveHadoopWriter {
     if (outputPath == null || fs == null) {
       throw new IllegalArgumentException("Incorrectly formatted output path")
     }
-    outputPath.makeQualified(fs)
+    outputPath.makeQualified(fs.getUri, fs.getWorkingDirectory)
+  }
+}
+
+private[spark] class SparkHiveDynamicPartitionWriterContainer(
+    @transient jobConf: JobConf,
+    fileSinkConf: FileSinkDesc,
+    dynamicPartColNames: Array[String])
+  extends SparkHiveWriterContainer(jobConf, fileSinkConf) {
+
+  private val defaultPartName = jobConf.get(
+    ConfVars.DEFAULTPARTITIONNAME.varname, ConfVars.DEFAULTPARTITIONNAME.defaultVal)
+
+  @transient private var writers: mutable.HashMap[String, FileSinkOperator.RecordWriter] = _
+
+  override def open(): Unit = {
+    writers = mutable.HashMap.empty[String, FileSinkOperator.RecordWriter]
+  }
+
+  override def close(): Unit = {
+    writers.values.foreach(_.close(false))
+  }
+
+  override def getLocalFileWriter(row: Row): FileSinkOperator.RecordWriter = {
+    val dynamicPartPath = dynamicPartColNames
+      .zip(row.takeRight(dynamicPartColNames.length))
+      .map { case (col, rawVal) =>
+        val string = String.valueOf(rawVal)
+        s"/$col=${if (rawVal == null || string.isEmpty) defaultPartName else string}"
+      }
+      .mkString
+
+    val path = {
+      val outputPath = FileOutputFormat.getOutputPath(conf.value)
+      assert(outputPath != null, "Undefined job output-path")
+      val workPath = new Path(outputPath, dynamicPartPath.stripPrefix("/"))
+      new Path(workPath, getOutputName)
+    }
+
+    def newWriter = {
+      val newFileSinkDesc = new FileSinkDesc(
+        fileSinkConf.getDirName + dynamicPartPath,
+        fileSinkConf.getTableInfo,
+        fileSinkConf.getCompressed)
+      newFileSinkDesc.setCompressCodec(fileSinkConf.getCompressCodec)
+      newFileSinkDesc.setCompressType(fileSinkConf.getCompressType)
+      HiveFileFormatUtils.getHiveRecordWriter(
+        conf.value,
+        fileSinkConf.getTableInfo,
+        conf.value.getOutputValueClass.asInstanceOf[Class[Writable]],
+        newFileSinkDesc,
+        path,
+        Reporter.NULL)
+    }
+
+    writers.getOrElseUpdate(dynamicPartPath, newWriter)
   }
 }
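
The diff only covers the writer containers themselves; the caller that feeds rows into them is not shown here. Below is a minimal, hypothetical sketch of how an insert operator might drive either container, assuming a `serializeRow` helper that turns a `Row` into the table's `Writable` representation and using `TaskContext.attemptId` as the task attempt number; none of these names come from this patch. For the dynamic-partition container, `getLocalFileWriter` selects the writer from the row's trailing partition columns, so with `dynamicPartColNames = Array("year", "month")` and a row ending in `2014, null`, output would land under `.../year=2014/month=__HIVE_DEFAULT_PARTITION__/` (Hive's default partition name).

```scala
import org.apache.hadoop.io.Writable
import org.apache.spark.TaskContext
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.Row

// Hypothetical driver/executor flow for the containers above; `saveAsHiveFile`
// and `serializeRow` are illustrative names, not part of this patch.
def saveAsHiveFile(
    rdd: RDD[Row],
    writerContainer: SparkHiveWriterContainer,
    serializeRow: Row => Writable): Unit = {
  // Driver side: fixed IDs (0, 0, 0) and OutputCommitter.setupJob().
  writerContainer.driverSideSetup()

  rdd.sparkContext.runJob(rdd, (context: TaskContext, iterator: Iterator[Row]) => {
    // Executor side: per-task IDs, then OutputCommitter.setupTask().
    // Hadoop expects a 32-bit attempt ID, hence the modulo.
    val attemptNumber = (context.attemptId % Int.MaxValue).toInt
    writerContainer.executorSideSetup(context.stageId, context.partitionId, attemptNumber)
    writerContainer.open()

    iterator.foreach { row =>
      // SparkHiveWriterContainer always returns its single task-level writer;
      // SparkHiveDynamicPartitionWriterContainer returns (or lazily creates) the
      // writer for the partition encoded in the row's trailing columns.
      writerContainer.getLocalFileWriter(row).write(serializeRow(row))
    }

    writerContainer.close()
    writerContainer.commit()
  })

  // Driver side again: commit the job once every task has committed its output.
  writerContainer.commitJob()
}
```

The driver calls `driverSideSetup()` and `commitJob()` exactly once, while each task runs its own setup/open/write/commit sequence, mirroring Hadoop's OutputCommitter protocol.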