
Commit 253f13e: More cleanup

1 parent 8e3ec20 commit 253f13e

1 file changed: +23 -21 lines

core/src/main/scala/org/apache/spark/shuffle/unsafe/UnsafeShuffleManager.scala
23 additions, 21 deletions
@@ -115,8 +115,8 @@ private[spark] class UnsafeShuffleWriter[K, V](
     val serializer = Serializer.getSerializer(dep.serializer).newInstance()
     val PAGE_SIZE = 1024 * 1024 * 1

-    var currentPage: MemoryBlock = memoryManager.allocatePage(PAGE_SIZE)
-    var currentPagePosition: Long = currentPage.getBaseOffset
+    var currentPage: MemoryBlock = null
+    var currentPagePosition: Long = PAGE_SIZE

     def ensureSpaceInDataPage(spaceRequired: Long): Unit = {
       if (spaceRequired > PAGE_SIZE) {
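The new initial values make page allocation lazy: with currentPage set to null and currentPagePosition set to PAGE_SIZE, the current page looks full, so the first call to ensureSpaceInDataPage must allocate a page before any record is written. The body of ensureSpaceInDataPage is not shown in this hunk, so the sketch below is an assumption about its shape, not the file's actual code:

// Hypothetical, self-contained sketch of the lazy-allocation pattern the new
// initial values rely on; the allocation logic here is assumed, not copied.
import org.apache.spark.unsafe.memory.MemoryBlock

class LazyPageHolder(pageSize: Long, allocate: Long => MemoryBlock) {
  var currentPage: MemoryBlock = null
  var currentPagePosition: Long = pageSize  // pretend the page is already full

  def ensureSpaceInDataPage(spaceRequired: Long): Unit = {
    require(spaceRequired <= pageSize, s"$spaceRequired bytes exceeds page size $pageSize")
    val spaceLeft =
      if (currentPage == null) 0L
      else pageSize - (currentPagePosition - currentPage.getBaseOffset)
    if (spaceRequired > spaceLeft) {
      // The first record lands here because spaceLeft is 0 until a page exists.
      currentPage = allocate(pageSize)
      currentPagePosition = currentPage.getBaseOffset
    }
  }
}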
@@ -143,6 +143,7 @@ private[spark] class UnsafeShuffleWriter[K, V](
       serBufferSerStream.flush()

       val serializedRecordSize = byteBuffer.position()
+      assert(serializedRecordSize > 0)
       // TODO: we should run the partition extraction function _now_, at insert time, rather than
       // requiring it to be stored alongisde the data, since this may lead to double storage
       val sizeRequirementInSortDataPage = serializedRecordSize + 8 + 8
@@ -152,17 +153,17 @@ private[spark] class UnsafeShuffleWriter[K, V](
         memoryManager.encodePageNumberAndOffset(currentPage, currentPagePosition)
       PlatformDependent.UNSAFE.putLong(currentPage.getBaseObject, currentPagePosition, partitionId)
       currentPagePosition += 8
-      println("The stored record length is " + byteBuffer.position())
+      println("The stored record length is " + serializedRecordSize)
       PlatformDependent.UNSAFE.putLong(
-        currentPage.getBaseObject, currentPagePosition, byteBuffer.position())
+        currentPage.getBaseObject, currentPagePosition, serializedRecordSize)
       currentPagePosition += 8
       PlatformDependent.copyMemory(
         serArray,
         PlatformDependent.BYTE_ARRAY_OFFSET,
         currentPage.getBaseObject,
         currentPagePosition,
-        byteBuffer.position())
-      currentPagePosition += byteBuffer.position()
+        serializedRecordSize)
+      currentPagePosition += serializedRecordSize
       println("After writing record, current page position is " + currentPagePosition)
       sorter.insertRecord(newRecordAddress)

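The hunk above writes each record with a fixed layout, which is why it reserves serializedRecordSize + 8 + 8 bytes: an 8-byte partition id at offset 0, an 8-byte record length at offset 8, and the serialized payload starting at offset 16. A hypothetical read-side helper, not part of the patch, showing the same layout that the loop in the later hunk walks byte by byte:

// Hypothetical helper (not in the patch) that reads one record back using the
// layout described above; PlatformDependent is the same utility the diff uses.
import org.apache.spark.unsafe.PlatformDependent

def readRecord(baseObject: AnyRef, baseOffset: Long): (Long, Array[Byte]) = {
  val partitionId  = PlatformDependent.UNSAFE.getLong(baseObject, baseOffset)      // bytes [0, 8)
  val recordLength = PlatformDependent.UNSAFE.getLong(baseObject, baseOffset + 8)  // bytes [8, 16)
  val bytes = new Array[Byte](recordLength.toInt)
  PlatformDependent.copyMemory(
    baseObject, baseOffset + 16,                  // source: payload starts after the two longs
    bytes, PlatformDependent.BYTE_ARRAY_OFFSET,   // destination: a fresh on-heap byte array
    recordLength)
  (partitionId, bytes)
}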
@@ -195,10 +196,12 @@ private[spark] class UnsafeShuffleWriter[K, V](
     }

     def switchToPartition(newPartition: Int): Unit = {
+      assert (newPartition > currentPartition, s"new partition $newPartition should be >= $currentPartition")
       if (currentPartition != -1) {
         closePartition()
         prevPartitionLength = partitionLengths(currentPartition)
       }
+      println(s"Before switching to partition $newPartition, partition lengths are " + partitionLengths.toSeq)
       currentPartition = newPartition
       out = blockManager.wrapForCompression(blockId, new FileOutputStream(outputFile, true))
     }
@@ -214,11 +217,11 @@ private[spark] class UnsafeShuffleWriter[K, V](
       val recordLength = PlatformDependent.UNSAFE.getLong(baseObject, baseOffset + 8)
       println("Base offset is " + baseOffset)
       println("Record length is " + recordLength)
-      var i: Int = 0
       // TODO: need to have a way to figure out whether a serializer supports relocation of
       // serialized objects or not. Sandy also ran into this in his patch (see
       // https://github.com/apache/spark/pull/4450). If we're using Java serialization, we might
       // as well just bypass this optimized code path in favor of the old one.
+      var i: Int = 0
       while (i < recordLength) {
         out.write(PlatformDependent.UNSAFE.getByte(baseObject, baseOffset + 16 + i))
         i += 1
@@ -241,6 +244,14 @@ private[spark] class UnsafeShuffleWriter[K, V](
     mapStatus = MapStatus(blockManager.shuffleServerId, partitionLengths)
   }

+  private def freeMemory(): Unit = {
+    val iter = allocatedPages.iterator()
+    while (iter.hasNext) {
+      memoryManager.freePage(iter.next())
+      iter.remove()
+    }
+  }
+
   /** Close this writer, passing along whether the map completed */
   override def stop(success: Boolean): Option[MapStatus] = {
     println("Stopping unsafeshufflewriter")
@@ -249,6 +260,7 @@ private[spark] class UnsafeShuffleWriter[K, V](
         None
       } else {
         stopping = true
+        freeMemory()
         if (success) {
           Option(mapStatus)
         } else {
@@ -258,24 +270,14 @@ private[spark] class UnsafeShuffleWriter[K, V](
         }
       }
     } finally {
-      // Clean up our sorter, which may have its own intermediate files
-      if (!allocatedPages.isEmpty) {
-        val iter = allocatedPages.iterator()
-        while (iter.hasNext) {
-          memoryManager.freePage(iter.next())
-          iter.remove()
-        }
-        val startTime = System.nanoTime()
-        //sorter.stop()
-        context.taskMetrics().shuffleWriteMetrics.foreach(
-          _.incShuffleWriteTime(System.nanoTime - startTime))
-      }
+      freeMemory()
+      val startTime = System.nanoTime()
+      context.taskMetrics().shuffleWriteMetrics.foreach(
+        _.incShuffleWriteTime(System.nanoTime - startTime))
     }
   }
 }

-
-
 private[spark] class UnsafeShuffleManager(conf: SparkConf) extends ShuffleManager {

   private[this] val sortShuffleManager: SortShuffleManager = new SortShuffleManager(conf)
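With this change, freeMemory() runs twice on a normal stop: once in the body of stop() and once more in the finally block. That is safe because each page is removed from allocatedPages as it is freed, so the second call iterates an empty collection. A self-contained sketch of that iterate-and-remove pattern follows; the list contents and names below are stand-ins, not the class's actual fields:

// Hypothetical stand-alone sketch of the iterate-and-remove pattern used by
// freeMemory(): removing each page as it is freed makes repeated calls harmless.
import java.util.{LinkedList => JLinkedList}

object FreeTwiceDemo {
  val allocatedPages = new JLinkedList[String]()  // stand-in for the real page list

  def freeMemory(): Unit = {
    val iter = allocatedPages.iterator()
    while (iter.hasNext) {
      println(s"freeing ${iter.next()}")          // stand-in for memoryManager.freePage(...)
      iter.remove()
    }
  }

  def main(args: Array[String]): Unit = {
    allocatedPages.add("page-0")
    allocatedPages.add("page-1")
    freeMemory()  // frees both pages
    freeMemory()  // no-op: the list is already empty, nothing is freed twice
  }
}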
