@@ -115,8 +115,8 @@ private[spark] class UnsafeShuffleWriter[K, V](
     val serializer = Serializer.getSerializer(dep.serializer).newInstance()
     val PAGE_SIZE = 1024 * 1024 * 1

-    var currentPage: MemoryBlock = memoryManager.allocatePage(PAGE_SIZE)
-    var currentPagePosition: Long = currentPage.getBaseOffset
+    var currentPage: MemoryBlock = null
+    var currentPagePosition: Long = PAGE_SIZE

     def ensureSpaceInDataPage(spaceRequired: Long): Unit = {
       if (spaceRequired > PAGE_SIZE) {
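With this change the writer no longer grabs a page eagerly: currentPage starts as null and currentPagePosition starts at PAGE_SIZE, so the first call to ensureSpaceInDataPage finds no room and must allocate on demand. A minimal sketch of the allocation path this implies (the exact bookkeeping is an assumption and not shown in this hunk; allocatedPages, memoryManager, allocatePage, and getBaseOffset are taken from elsewhere in the class):

    // Sketch only: lazily allocate a fresh page when the current one cannot hold the record.
    def ensureSpaceInDataPageSketch(spaceRequired: Long): Unit = {
      if (spaceRequired > PAGE_SIZE) {
        throw new IllegalArgumentException(
          s"Record of size $spaceRequired is larger than page size $PAGE_SIZE")
      }
      val spaceLeft =
        if (currentPage == null) 0L
        else PAGE_SIZE - (currentPagePosition - currentPage.getBaseOffset)
      if (spaceRequired > spaceLeft) {
        currentPage = memoryManager.allocatePage(PAGE_SIZE)
        allocatedPages.add(currentPage)          // tracked so freeMemory() can release it later
        currentPagePosition = currentPage.getBaseOffset
      }
    }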
@@ -143,6 +143,7 @@ private[spark] class UnsafeShuffleWriter[K, V](
       serBufferSerStream.flush()

       val serializedRecordSize = byteBuffer.position()
+      assert(serializedRecordSize > 0)
       // TODO: we should run the partition extraction function _now_, at insert time, rather than
       // requiring it to be stored alongside the data, since this may lead to double storage
       val sizeRequirementInSortDataPage = serializedRecordSize + 8 + 8
@@ -152,17 +153,17 @@ private[spark] class UnsafeShuffleWriter[K, V](
         memoryManager.encodePageNumberAndOffset(currentPage, currentPagePosition)
       PlatformDependent.UNSAFE.putLong(currentPage.getBaseObject, currentPagePosition, partitionId)
       currentPagePosition += 8
-      println("The stored record length is " + byteBuffer.position())
+      println("The stored record length is " + serializedRecordSize)
       PlatformDependent.UNSAFE.putLong(
-        currentPage.getBaseObject, currentPagePosition, byteBuffer.position())
+        currentPage.getBaseObject, currentPagePosition, serializedRecordSize)
       currentPagePosition += 8
       PlatformDependent.copyMemory(
         serArray,
         PlatformDependent.BYTE_ARRAY_OFFSET,
         currentPage.getBaseObject,
         currentPagePosition,
-        byteBuffer.position())
-      currentPagePosition += byteBuffer.position()
+        serializedRecordSize)
+      currentPagePosition += serializedRecordSize
       println("After writing record, current page position is " + currentPagePosition)
       sorter.insertRecord(newRecordAddress)

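Taken together, this hunk makes the length written into the page the flushed serializedRecordSize rather than re-reading byteBuffer.position(). The resulting per-record layout in the data page, reconstructed from the writes above (the pointer packing inside encodePageNumberAndOffset is not shown in this diff):

    // Layout of one record at the address handed to sorter.insertRecord:
    //   offset  0: partitionId             (Long, 8 bytes)
    //   offset  8: serializedRecordSize    (Long, 8 bytes)
    //   offset 16: serialized record bytes (serializedRecordSize bytes)
    // Total footprint per record: serializedRecordSize + 16 bytes, which matches
    // sizeRequirementInSortDataPage = serializedRecordSize + 8 + 8.

The sorter only ever receives the encoded (page, offset) address, so records can be reordered without copying the serialized bytes.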
@@ -195,10 +196,12 @@ private[spark] class UnsafeShuffleWriter[K, V](
     }

     def switchToPartition(newPartition: Int): Unit = {
+      assert(newPartition > currentPartition, s"new partition $newPartition should be > $currentPartition")
       if (currentPartition != -1) {
         closePartition()
         prevPartitionLength = partitionLengths(currentPartition)
       }
+      println(s"Before switching to partition $newPartition, partition lengths are " + partitionLengths.toSeq)
       currentPartition = newPartition
       out = blockManager.wrapForCompression(blockId, new FileOutputStream(outputFile, true))
     }
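The new assert documents the contract that switchToPartition is only ever called with strictly increasing partition ids, i.e. the caller walks the sorted records partition by partition. A rough sketch of that calling pattern, under the assumption that the sorted iterator yields (partitionId, recordAddress) pairs (neither sortedRecords nor writeRecordTo appears in this diff; both are hypothetical):

    // Assumed driving loop; sortedRecords and writeRecordTo are hypothetical names.
    while (sortedRecords.hasNext) {
      val (partitionId, recordAddress) = sortedRecords.next()
      if (partitionId != currentPartition) {
        switchToPartition(partitionId)   // assert guarantees partitionId > previous value
      }
      writeRecordTo(out, recordAddress)  // copy the record bytes into the partition's stream
    }
    closePartition()                     // flush the last open partition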
@@ -214,11 +217,11 @@ private[spark] class UnsafeShuffleWriter[K, V](
       val recordLength = PlatformDependent.UNSAFE.getLong(baseObject, baseOffset + 8)
       println("Base offset is " + baseOffset)
       println("Record length is " + recordLength)
-      var i: Int = 0
       // TODO: need to have a way to figure out whether a serializer supports relocation of
       // serialized objects or not. Sandy also ran into this in his patch (see
       // https://github.com/apache/spark/pull/4450). If we're using Java serialization, we might
       // as well just bypass this optimized code path in favor of the old one.
+      var i: Int = 0
       while (i < recordLength) {
         out.write(PlatformDependent.UNSAFE.getByte(baseObject, baseOffset + 16 + i))
         i += 1
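The read side above mirrors the write layout: partition id at offset 0, record length at offset 8, payload starting at offset 16. As an aside, the per-byte loop could in principle become one bulk copy into a temporary array; the sketch below assumes PlatformDependent.copyMemory takes (srcBase, srcOffset, dstBase, dstOffset, length) as used earlier in this file, and is not part of this commit:

    // Sketch: bulk copy of one record's payload instead of byte-at-a-time writes.
    val length = PlatformDependent.UNSAFE.getLong(baseObject, baseOffset + 8).toInt
    val scratch = new Array[Byte](length)
    PlatformDependent.copyMemory(
      baseObject, baseOffset + 16,                    // skip partitionId (8) + length (8)
      scratch, PlatformDependent.BYTE_ARRAY_OFFSET,
      length)
    out.write(scratch, 0, length)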
@@ -241,6 +244,14 @@ private[spark] class UnsafeShuffleWriter[K, V](
     mapStatus = MapStatus(blockManager.shuffleServerId, partitionLengths)
   }

+  private def freeMemory(): Unit = {
+    val iter = allocatedPages.iterator()
+    while (iter.hasNext) {
+      memoryManager.freePage(iter.next())
+      iter.remove()
+    }
+  }
+
   /** Close this writer, passing along whether the map completed */
   override def stop(success: Boolean): Option[MapStatus] = {
     println("Stopping unsafeshufflewriter")
@@ -249,6 +260,7 @@ private[spark] class UnsafeShuffleWriter[K, V](
         None
       } else {
         stopping = true
+        freeMemory()
         if (success) {
           Option(mapStatus)
         } else {
@@ -258,24 +270,14 @@ private[spark] class UnsafeShuffleWriter[K, V](
         }
       }
     } finally {
-      // Clean up our sorter, which may have its own intermediate files
-      if (!allocatedPages.isEmpty) {
-        val iter = allocatedPages.iterator()
-        while (iter.hasNext) {
-          memoryManager.freePage(iter.next())
-          iter.remove()
-        }
-        val startTime = System.nanoTime()
-        // sorter.stop()
-        context.taskMetrics().shuffleWriteMetrics.foreach(
-          _.incShuffleWriteTime(System.nanoTime - startTime))
-      }
+      freeMemory()
+      val startTime = System.nanoTime()
+      context.taskMetrics().shuffleWriteMetrics.foreach(
+        _.incShuffleWriteTime(System.nanoTime - startTime))
     }
   }
 }

-
-

 private[spark] class UnsafeShuffleManager(conf: SparkConf) extends ShuffleManager {

   private[this] val sortShuffleManager: SortShuffleManager = new SortShuffleManager(conf)