Skip to content

Commit

Permalink
[SPARK-48394][CORE] Cleanup mapIdToMapIndex on mapoutput unregister
Browse files Browse the repository at this point in the history
This PR cleans up `mapIdToMapIndex` when the corresponding mapstatus is unregistered in three places:
* `removeMapOutput`
* `removeOutputsByFilter`
* `addMapOutput` (old mapstatus overwritten)

There is only one valid mapstatus for a given `mapIndex` at any time in Spark. `mapIdToMapIndex` should follow the same rule, so that it never keeps a stale `mapId` entry pointing at a `mapIndex` whose mapstatus has been removed or replaced.

No.

Unit tests.

No.

Closes #46706 from Ngone51/SPARK-43043-followup.

Lead-authored-by: Yi Wu <yi.wu@databricks.com>
Co-authored-by: wuyi <yi.wu@databricks.com>
Signed-off-by: Dongjoon Hyun <dhyun@apple.com>
  • Loading branch information
Ngone51 committed May 31, 2024
1 parent d64f96c commit 15ed5a0
Show file tree
Hide file tree
Showing 2 changed files with 72 additions and 9 deletions.
26 changes: 17 additions & 9 deletions core/src/main/scala/org/apache/spark/MapOutputTracker.scala
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,6 @@ import org.apache.spark.scheduler.{MapStatus, MergeStatus, ShuffleOutputStatus}
import org.apache.spark.shuffle.MetadataFetchFailedException
import org.apache.spark.storage.{BlockId, BlockManagerId, ShuffleBlockId, ShuffleMergedBlockId}
import org.apache.spark.util._
import org.apache.spark.util.collection.OpenHashMap
import org.apache.spark.util.io.{ChunkedByteBuffer, ChunkedByteBufferOutputStream}

/**
Expand Down Expand Up @@ -151,17 +150,22 @@ private class ShuffleStatus(
/**
* Mapping from a mapId to the mapIndex, this is required to reduce the searching overhead within
* the function updateMapOutput(mapId, bmAddress).
*
* Exposed for testing.
*/
private[this] val mapIdToMapIndex = new OpenHashMap[Long, Int]()
private[spark] val mapIdToMapIndex = new HashMap[Long, Int]()

/**
* Register a map output. If there is already a registered location for the map output then it
* will be replaced by the new location.
*/
def addMapOutput(mapIndex: Int, status: MapStatus): Unit = withWriteLock {
if (mapStatuses(mapIndex) == null) {
val currentMapStatus = mapStatuses(mapIndex)
if (currentMapStatus == null) {
_numAvailableMapOutputs += 1
invalidateSerializedMapOutputStatusCache()
} else {
mapIdToMapIndex.remove(currentMapStatus.mapId)
}
mapStatuses(mapIndex) = status
mapIdToMapIndex(status.mapId) = mapIndex
Expand Down Expand Up @@ -190,8 +194,8 @@ private class ShuffleStatus(
mapStatus.updateLocation(bmAddress)
invalidateSerializedMapOutputStatusCache()
case None =>
if (mapIndex.map(mapStatusesDeleted).exists(_.mapId == mapId)) {
val index = mapIndex.get
val index = mapStatusesDeleted.indexWhere(x => x != null && x.mapId == mapId)
if (index >= 0 && mapStatuses(index) == null) {
val mapStatus = mapStatusesDeleted(index)
mapStatus.updateLocation(bmAddress)
mapStatuses(index) = mapStatus
Expand All @@ -216,9 +220,11 @@ private class ShuffleStatus(
*/
def removeMapOutput(mapIndex: Int, bmAddress: BlockManagerId): Unit = withWriteLock {
logDebug(s"Removing existing map output ${mapIndex} ${bmAddress}")
if (mapStatuses(mapIndex) != null && mapStatuses(mapIndex).location == bmAddress) {
val currentMapStatus = mapStatuses(mapIndex)
if (currentMapStatus != null && currentMapStatus.location == bmAddress) {
_numAvailableMapOutputs -= 1
mapStatusesDeleted(mapIndex) = mapStatuses(mapIndex)
mapIdToMapIndex.remove(currentMapStatus.mapId)
mapStatusesDeleted(mapIndex) = currentMapStatus
mapStatuses(mapIndex) = null
invalidateSerializedMapOutputStatusCache()
}
Expand Down Expand Up @@ -284,9 +290,11 @@ private class ShuffleStatus(
*/
def removeOutputsByFilter(f: BlockManagerId => Boolean): Unit = withWriteLock {
for (mapIndex <- mapStatuses.indices) {
if (mapStatuses(mapIndex) != null && f(mapStatuses(mapIndex).location)) {
val currentMapStatus = mapStatuses(mapIndex)
if (currentMapStatus != null && f(currentMapStatus.location)) {
_numAvailableMapOutputs -= 1
mapStatusesDeleted(mapIndex) = mapStatuses(mapIndex)
mapIdToMapIndex.remove(currentMapStatus.mapId)
mapStatusesDeleted(mapIndex) = currentMapStatus
mapStatuses(mapIndex) = null
invalidateSerializedMapOutputStatusCache()
}
Expand Down
55 changes: 55 additions & 0 deletions core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1109,4 +1109,59 @@ class MapOutputTrackerSuite extends SparkFunSuite with LocalSparkContext {
rpcEnv.shutdown()
}
}

// Registers a single map output (map index 0) for shuffle 0 on hostA, removes every
// output on that host, and then checks that no mapId -> mapIndex entry still refers to
// map index 0.
// NOTE(review): removeOutputsOnHost is presumably routed through
// ShuffleStatus.removeOutputsByFilter (per the test name) -- confirm.
test(
  "SPARK-48394: mapIdToMapIndex should cleanup unused mapIndexes after removeOutputsByFilter"
) {
  val env = createRpcEnv("test")
  val master = newTrackerMaster()
  try {
    val endpoint = new MapOutputTrackerMasterEndpoint(env, master, conf)
    master.trackerEndpoint = env.setupEndpoint(MapOutputTracker.ENDPOINT_NAME, endpoint)
    master.registerShuffle(0, 1, 1)

    val location = BlockManagerId("exec-1", "hostA", 1000)
    master.registerMapOutput(0, 0, MapStatus(location, Array(2L), 0))
    master.removeOutputsOnHost("hostA")

    // After the removal nothing may still be mapped to map index 0.
    val indexStillMapped = master.shuffleStatuses(0).mapIdToMapIndex.exists {
      case (_, mapIndex) => mapIndex == 0
    }
    assert(!indexStillMapped)
  } finally {
    master.stop()
    env.shutdown()
  }
}

// Registers a single map output (map index 0) for shuffle 0, unregisters it again for the
// same block manager address, and then checks that no mapId -> mapIndex entry still refers
// to map index 0.
test("SPARK-48394: mapIdToMapIndex should cleanup unused mapIndexes after unregisterMapOutput") {
  val env = createRpcEnv("test")
  val master = newTrackerMaster()
  try {
    val endpoint = new MapOutputTrackerMasterEndpoint(env, master, conf)
    master.trackerEndpoint = env.setupEndpoint(MapOutputTracker.ENDPOINT_NAME, endpoint)
    master.registerShuffle(0, 1, 1)

    master.registerMapOutput(0, 0, MapStatus(BlockManagerId("exec-1", "hostA", 1000),
      Array(2L), 0))
    // Unregister using an address equal to the one the output was registered with.
    val sameAddress = BlockManagerId("exec-1", "hostA", 1000)
    master.unregisterMapOutput(0, 0, sameAddress)

    val indexStillMapped = master.shuffleStatuses(0).mapIdToMapIndex.exists {
      case (_, mapIndex) => mapIndex == 0
    }
    assert(!indexStillMapped)
  } finally {
    master.stop()
    env.shutdown()
  }
}

// Registers two map outputs for the same map index (0) of shuffle 0, the second one
// overwriting the first, and verifies that mapIdToMapIndex keeps exactly one entry for that
// index -- the one belonging to the most recently registered status.
test("SPARK-48394: mapIdToMapIndex should cleanup unused mapIndexes after registerMapOutput") {
  val rpcEnv = createRpcEnv("test")
  val tracker = newTrackerMaster()
  try {
    tracker.trackerEndpoint = rpcEnv.setupEndpoint(MapOutputTracker.ENDPOINT_NAME,
      new MapOutputTrackerMasterEndpoint(rpcEnv, tracker, conf))
    tracker.registerShuffle(0, 1, 1)
    // First attempt: last MapStatus argument (the map task id) is 0.
    tracker.registerMapOutput(0, 0, MapStatus(BlockManagerId("exec-1", "hostA", 1000),
      Array(2L), 0))
    // Another task also finished working on partition 0; its map task id is 1.
    tracker.registerMapOutput(0, 0, MapStatus(BlockManagerId("exec-2", "hostB", 1000),
      Array(2L), 1))

    val mapIdToMapIndex = tracker.shuffleStatuses(0).mapIdToMapIndex
    // Exactly one mapId may point at map index 0 ...
    assert(mapIdToMapIndex.filter(_._2 == 0).size == 1)
    // ... and it must be the entry of the overwriting status, not the stale one. Counting
    // alone would not tell the two apart.
    assert(!mapIdToMapIndex.contains(0L))
    assert(mapIdToMapIndex.get(1L) == Some(0))
  } finally {
    tracker.stop()
    rpcEnv.shutdown()
  }
}
}

0 comments on commit 15ed5a0

Please sign in to comment.