[WIP] Add metadata to MapStatus #4

Closed · wants to merge 2 commits into from

Changes from all commits
@@ -17,14 +17,17 @@

package org.apache.spark.shuffle.api.metadata;

import java.io.Serializable;
import org.apache.spark.annotation.DeveloperApi;

import java.io.Externalizable;

/**
* :: Private ::
* An opaque metadata tag for registering the result of committing the output of a
* shuffle map task.
* :: DeveloperApi ::
* Metadata for registering the result of committing the output of a shuffle map task.
* <p>
* All implementations must be serializable since this is sent from the executors to
* All implementations must be externalizable since this is sent from the executors to
* the driver.
*/
public interface MapOutputMetadata extends Serializable {}
@DeveloperApi
public interface MapOutputMetadata extends Externalizable {}
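For reference, a minimal sketch of what an implementation of this interface could look like; the RowCountMetadata name and its payload are hypothetical and not part of this PR. Because the interface now extends Externalizable rather than Serializable, an implementation needs a public no-arg constructor plus explicit write/read methods:

```scala
import java.io.{ObjectInput, ObjectOutput}

import org.apache.spark.shuffle.api.metadata.MapOutputMetadata

// Hypothetical metadata: how many rows a map task wrote to its shuffle output.
class RowCountMetadata(private var rowCount: Long) extends MapOutputMetadata {

  // Externalizable deserialization instantiates the class via its no-arg constructor.
  def this() = this(0L)

  def rows: Long = rowCount

  override def writeExternal(out: ObjectOutput): Unit = out.writeLong(rowCount)

  override def readExternal(in: ObjectInput): Unit = {
    rowCount = in.readLong()
  }
}
```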

@@ -0,0 +1,21 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.shuffle.api.metadata;

public interface MapOutputMetadataFactory {
MapOutputMetadata create();
}
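This factory is what lets the deserializer know which concrete metadata class to materialize (see the MapStatus.readExternal changes further down). A sketch pairing it with the hypothetical RowCountMetadata above; a factory whose create() returns null, as SortShuffleManager does at the bottom of this diff, signals that no metadata is attached:

```scala
import org.apache.spark.shuffle.api.metadata.{MapOutputMetadata, MapOutputMetadataFactory}

// Hypothetical factory: hand the deserializer a fresh, empty instance to read into.
class RowCountMetadataFactory extends MapOutputMetadataFactory {
  override def create(): MapOutputMetadata = new RowCountMetadata()
}
```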
@@ -141,7 +141,7 @@ public void write(Iterator<Product2<K, V>> records) throws IOException {
partitionLengths = mapOutputWriter.commitAllPartitions(
ShuffleChecksumHelper.EMPTY_CHECKSUM_VALUE).getPartitionLengths();
mapStatus = MapStatus$.MODULE$.apply(
blockManager.shuffleServerId(), partitionLengths, mapId);
blockManager.shuffleServerId(), partitionLengths, mapId, null);
return;
}
final SerializerInstance serInstance = serializer.newInstance();
@@ -179,7 +179,7 @@ public void write(Iterator<Product2<K, V>> records) throws IOException {

partitionLengths = writePartitionedData(mapOutputWriter);
mapStatus = MapStatus$.MODULE$.apply(
blockManager.shuffleServerId(), partitionLengths, mapId);
blockManager.shuffleServerId(), partitionLengths, mapId, null);
} catch (Exception e) {
try {
mapOutputWriter.abort(e);
@@ -231,7 +231,7 @@ void closeAndWriteOutput() throws IOException {
}
}
mapStatus = MapStatus$.MODULE$.apply(
blockManager.shuffleServerId(), partitionLengths, mapId);
blockManager.shuffleServerId(), partitionLengths, mapId, null);
}

@VisibleForTesting
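All three writer call sites above pass null for the new metadata argument, so the plumbing is in place but nothing flows through it yet. A hedged sketch of what a metadata-aware call site could look like; the helper below is illustrative and sits inside the org.apache.spark namespace because MapStatus.apply is private[spark]:

```scala
package org.apache.spark.shuffle.sort

import org.apache.spark.scheduler.MapStatus
import org.apache.spark.shuffle.api.metadata.MapOutputMetadata
import org.apache.spark.storage.BlockManagerId

// Illustrative: a writer that collected some per-task statistic would pass it
// here instead of null, and it travels to the driver inside the MapStatus.
private[spark] object MetadataAwareMapStatus {
  def build(
      shuffleServerId: BlockManagerId,
      partitionLengths: Array[Long],
      mapId: Long,
      metadata: MapOutputMetadata): MapStatus = {
    MapStatus(shuffleServerId, partitionLengths, mapId, metadata)
  }
}
```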
core/src/main/scala/org/apache/spark/MapOutputTracker.scala (81 additions, 0 deletions)
@@ -39,6 +39,7 @@ import org.apache.spark.io.CompressionCodec
import org.apache.spark.rpc.{RpcCallContext, RpcEndpoint, RpcEndpointRef, RpcEnv}
import org.apache.spark.scheduler.{MapStatus, MergeStatus, ShuffleOutputStatus}
import org.apache.spark.shuffle.MetadataFetchFailedException
import org.apache.spark.shuffle.api.metadata.MapOutputMetadata
import org.apache.spark.storage.{BlockId, BlockManagerId, ShuffleBlockId, ShuffleMergedBlockId}
import org.apache.spark.util._
import org.apache.spark.util.io.{ChunkedByteBuffer, ChunkedByteBufferOutputStream}
@@ -438,6 +439,7 @@ private[spark] case class GetMapAndMergeOutputMessage(shuffleId: Int,
context: RpcCallContext) extends MapOutputTrackerMasterMessage
private[spark] case class MapSizesByExecutorId(
iter: Iterator[(BlockManagerId, Seq[(BlockId, Long, Int)])], enableBatchFetch: Boolean)
case class BlockManagerIdWithMeta(blockManagerId: BlockManagerId, metadata: MapOutputMetadata)

/** RpcEndpoint class for MapOutputTrackerMaster */
private[spark] class MapOutputTrackerMasterEndpoint(
@@ -540,6 +542,13 @@ private[spark] abstract class MapOutputTracker(conf: SparkConf) extends Logging
startPartition: Int,
endPartition: Int): Iterator[(BlockManagerId, Seq[(BlockId, Long, Int)])]

def getMapSizesByBmIdWithMetadata(
shuffleId: Int,
startMapIndex: Int,
endMapIndex: Int,
startPartition: Int,
endPartition: Int): Iterator[(BlockManagerIdWithMeta, Seq[(BlockId, Long, Int)])]

/**
* Called from executors to get the server URIs and output sizes for each shuffle block that
* needs to be read from a given range of map output partitions (startPartition is included but
@@ -1125,6 +1134,27 @@ private[spark] class MapOutputTrackerMaster(
}
}

override def getMapSizesByBmIdWithMetadata(
shuffleId: Int,
startMapIndex: Int,
endMapIndex: Int,
startPartition: Int,
endPartition: Int): Iterator[(BlockManagerIdWithMeta, Seq[(BlockId, Long, Int)])] = {
logDebug(s"Fetching outputs for shuffle $shuffleId")
shuffleStatuses.get(shuffleId) match {
case Some(shuffleStatus) =>
shuffleStatus.withMapStatuses { statuses =>
val actualEndMapIndex = if (endMapIndex == Int.MaxValue) statuses.length else endMapIndex
logDebug(s"Convert map statuses for shuffle $shuffleId, " +
s"mappers $startMapIndex-$actualEndMapIndex, partitions $startPartition-$endPartition")
MapOutputTracker.convertMapStatusesWithMeta(
shuffleId, startPartition, endPartition, statuses, startMapIndex, actualEndMapIndex)
}
case None =>
Iterator.empty
}
}

// This method is only called in local-mode. Since push based shuffle won't be
// enabled in local-mode, this method returns empty list.
override def getMapSizesForMergeResult(
@@ -1182,6 +1212,32 @@ private[spark] class MapOutputTrackerWorker(conf: SparkConf) extends MapOutputTr
*/
private val fetchingLock = new KeyLock[Int]

override def getMapSizesByBmIdWithMetadata(
shuffleId: Int,
startMapIndex: Int,
endMapIndex: Int,
startPartition: Int,
endPartition: Int): Iterator[(BlockManagerIdWithMeta, Seq[(BlockId, Long, Int)])] = {
logDebug(s"Fetching outputs for shuffle $shuffleId")
val (mapOutputStatuses, mergedOutputStatuses) = getStatuses(shuffleId, conf, false)
try {
val actualEndMapIndex =
if (endMapIndex == Int.MaxValue) mapOutputStatuses.length else endMapIndex
logDebug(s"Convert map statuses for shuffle $shuffleId, " +
s"mappers $startMapIndex-$actualEndMapIndex, partitions $startPartition-$endPartition")
MapOutputTracker.convertMapStatusesWithMeta(
shuffleId, startPartition, endPartition, mapOutputStatuses, startMapIndex,
actualEndMapIndex)
} catch {
case e: MetadataFetchFailedException =>
// We experienced a fetch failure so our mapStatuses cache is outdated; clear it:
mapStatuses.clear()
mergeStatuses.clear()
throw e
}
}

override def getMapSizesByExecutorId(
shuffleId: Int,
startMapIndex: Int,
@@ -1577,6 +1633,31 @@ private[spark] object MapOutputTracker extends Logging {
MapSizesByExecutorId(splitsByAddress.mapValues(_.toSeq).iterator, enableBatchFetch)
}

def convertMapStatusesWithMeta(
shuffleId: Int,
startPartition: Int,
endPartition: Int,
mapStatuses: Array[MapStatus],
startMapIndex: Int,
endMapIndex: Int): Iterator[(BlockManagerIdWithMeta, Seq[(BlockId, Long, Int)])] = {
assert (mapStatuses != null)
val splitsByAddress = new HashMap[BlockManagerIdWithMeta, ListBuffer[(BlockId, Long, Int)]]
val iter = mapStatuses.iterator.zipWithIndex
for ((status, mapIndex) <- iter.slice(startMapIndex, endMapIndex)) {
validateStatus(status, shuffleId, startPartition)
for (part <- startPartition until endPartition) {
val size = status.getSizeForBlock(part)
if (size != 0) {
val blockManagerIdWithMeta = BlockManagerIdWithMeta(status.location, status.metadata)
splitsByAddress
.getOrElseUpdate(blockManagerIdWithMeta, ListBuffer()) +=
((ShuffleBlockId(shuffleId, status.mapId, part), size, mapIndex))
}
}
}
splitsByAddress.mapValues(_.toSeq).iterator
}

/**
* Given a shuffle ID, a partition ID, an array of map statuses, and bitmap corresponding
* to either a merged shuffle partition or a merged shuffle partition chunk, identify
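On the read side, the new getMapSizesByBmIdWithMetadata mirrors getMapSizesByExecutorId but keys the result by BlockManagerIdWithMeta, so callers can see each mapper's metadata next to its blocks. A hedged usage sketch (the helper object is illustrative; only the tracker call itself comes from this diff):

```scala
package org.apache.spark

import org.apache.spark.shuffle.api.metadata.MapOutputMetadata

// Illustrative: gather the metadata of every mapper that produced data for one
// reduce partition. Lives inside org.apache.spark since the tracker is private[spark].
private[spark] object ShuffleMetadataLookup {
  def metadataForPartition(shuffleId: Int, reduceId: Int): Seq[MapOutputMetadata] = {
    SparkEnv.get.mapOutputTracker
      .getMapSizesByBmIdWithMetadata(
        shuffleId,
        startMapIndex = 0,
        endMapIndex = Int.MaxValue, // Int.MaxValue means "through the last map output"
        startPartition = reduceId,
        endPartition = reduceId + 1)
      .map { case (bmIdWithMeta, _) => bmIdWithMeta.metadata }
      .filter(_ != null) // statuses written without metadata carry null
      .toSeq
  }
}
```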
core/src/main/scala/org/apache/spark/scheduler/MapStatus.scala (41 additions, 11 deletions)
@@ -25,6 +25,7 @@ import org.roaringbitmap.RoaringBitmap

import org.apache.spark.SparkEnv
import org.apache.spark.internal.config
import org.apache.spark.shuffle.api.metadata.MapOutputMetadata
import org.apache.spark.storage.BlockManagerId
import org.apache.spark.util.Utils

@@ -45,6 +46,8 @@ private[spark] sealed trait MapStatus extends ShuffleOutputStatus {

def updateLocation(newLoc: BlockManagerId): Unit

def metadata: MapOutputMetadata

/**
* Estimated size for the reduce block, in bytes.
*
@@ -74,11 +77,12 @@ private[spark] object MapStatus {
def apply(
loc: BlockManagerId,
uncompressedSizes: Array[Long],
mapTaskId: Long): MapStatus = {
mapTaskId: Long,
metadata: MapOutputMetadata = null): MapStatus = {
if (uncompressedSizes.length > minPartitionsToUseHighlyCompressMapStatus) {
HighlyCompressedMapStatus(loc, uncompressedSizes, mapTaskId)
HighlyCompressedMapStatus(loc, uncompressedSizes, mapTaskId, metadata)
} else {
new CompressedMapStatus(loc, uncompressedSizes, mapTaskId)
new CompressedMapStatus(loc, uncompressedSizes, mapTaskId, metadata)
}
}

@@ -123,14 +127,19 @@
private[spark] class CompressedMapStatus(
private[this] var loc: BlockManagerId,
private[this] var compressedSizes: Array[Byte],
private[this] var _mapTaskId: Long)
private[this] var _mapTaskId: Long,
private[this] var _metadata: MapOutputMetadata)
extends MapStatus with Externalizable {

// For deserialization only
protected def this() = this(null, null.asInstanceOf[Array[Byte]], -1)
protected def this() = this(null, null.asInstanceOf[Array[Byte]], -1, null)

def this(loc: BlockManagerId, uncompressedSizes: Array[Long], mapTaskId: Long) = {
this(loc, uncompressedSizes.map(MapStatus.compressSize), mapTaskId)
def this(
loc: BlockManagerId,
uncompressedSizes: Array[Long],
mapTaskId: Long,
mapOutputMetadata: MapOutputMetadata = null) = {
this(loc, uncompressedSizes.map(MapStatus.compressSize), mapTaskId, mapOutputMetadata)
}

override def location: BlockManagerId = loc
@@ -145,11 +154,16 @@ private[spark] class CompressedMapStatus(

override def mapId: Long = _mapTaskId

override def metadata: MapOutputMetadata = _metadata

override def writeExternal(out: ObjectOutput): Unit = Utils.tryOrIOException {
loc.writeExternal(out)
out.writeInt(compressedSizes.length)
out.write(compressedSizes)
out.writeLong(_mapTaskId)
if (_metadata != null) {
_metadata.writeExternal(out)
}
}

override def readExternal(in: ObjectInput): Unit = Utils.tryOrIOException {
@@ -158,6 +172,10 @@
compressedSizes = new Array[Byte](len)
in.readFully(compressedSizes)
_mapTaskId = in.readLong()
_metadata = SparkEnv.get.shuffleManager.mapOutputMetadataFactory.create()
if (_metadata != null) {
_metadata.readExternal(in)
}
}
}

@@ -179,15 +197,19 @@
private[this] var emptyBlocks: RoaringBitmap,
private[this] var avgSize: Long,
private[this] var hugeBlockSizes: scala.collection.Map[Int, Byte],
private[this] var _mapTaskId: Long)
private[this] var _mapTaskId: Long,
private[this] var _metadata: MapOutputMetadata)
extends MapStatus with Externalizable {

override def metadata: MapOutputMetadata = _metadata

// loc could be null when the default constructor is called during deserialization
require(loc == null || avgSize > 0 || hugeBlockSizes.size > 0
|| numNonEmptyBlocks == 0 || _mapTaskId > 0,
"Average size can only be zero for map stages that produced no output")

protected def this() = this(null, -1, null, -1, null, -1) // For deserialization only
protected def this() =
this(null, -1, null, -1, null, -1, null) // For deserialization only

override def location: BlockManagerId = loc

@@ -219,6 +241,9 @@
out.writeByte(kv._2)
}
out.writeLong(_mapTaskId)
if (_metadata != null) {
_metadata.writeExternal(out)
}
}

override def readExternal(in: ObjectInput): Unit = Utils.tryOrIOException {
@@ -236,14 +261,19 @@
}
hugeBlockSizes = hugeBlockSizesImpl
_mapTaskId = in.readLong()
_metadata = SparkEnv.get.shuffleManager.mapOutputMetadataFactory.create()
if (_metadata != null) {
_metadata.readExternal(in)
}
}
}

private[spark] object HighlyCompressedMapStatus {
def apply(
loc: BlockManagerId,
uncompressedSizes: Array[Long],
mapTaskId: Long): HighlyCompressedMapStatus = {
mapTaskId: Long,
metadata: MapOutputMetadata = null): HighlyCompressedMapStatus = {
// We must keep track of which blocks are empty so that we don't report a zero-sized
// block as being non-empty (or vice-versa) when using the average block size.
var i = 0
@@ -310,6 +340,6 @@ private[spark] object HighlyCompressedMapStatus {
emptyBlocks.trim()
emptyBlocks.runOptimize()
new HighlyCompressedMapStatus(loc, numNonEmptyBlocks, emptyBlocks, avgSize,
hugeBlockSizes, mapTaskId)
hugeBlockSizes, mapTaskId, metadata)
}
}
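A detail worth noting in the two readExternal changes above: the map side writes the metadata only when one was attached, while the read side decides whether to read one based on what the local shuffle manager's factory creates, so the two sides have to agree or the object stream goes out of alignment. A small round-trip sketch of the metadata payload on its own, reusing the hypothetical RowCountMetadata from earlier (a full MapStatus round trip additionally needs a live SparkEnv, since readExternal calls SparkEnv.get.shuffleManager.mapOutputMetadataFactory):

```scala
import java.io.{ByteArrayInputStream, ByteArrayOutputStream, ObjectInputStream, ObjectOutputStream}

// Map side: serialize the metadata exactly as MapStatus.writeExternal would.
val buffer = new ByteArrayOutputStream()
val out = new ObjectOutputStream(buffer)
new RowCountMetadata(42L).writeExternal(out)
out.flush()

// Driver side: the factory supplies an empty instance, readExternal fills it in.
val in = new ObjectInputStream(new ByteArrayInputStream(buffer.toByteArray))
val restored = new RowCountMetadata() // what a MapOutputMetadataFactory.create() would return
restored.readExternal(in)
assert(restored.rows == 42L)
```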
@@ -18,6 +18,7 @@
package org.apache.spark.shuffle

import org.apache.spark.{ShuffleDependency, TaskContext}
import org.apache.spark.shuffle.api.metadata.MapOutputMetadataFactory

/**
* Pluggable interface for shuffle systems. A ShuffleManager is created in SparkEnv on the driver
@@ -91,6 +92,8 @@
*/
def shuffleBlockResolver: ShuffleBlockResolver

val mapOutputMetadataFactory: MapOutputMetadataFactory

/** Shut down this ShuffleManager. */
def stop(): Unit
}
@@ -25,6 +25,7 @@ import org.apache.spark._
import org.apache.spark.internal.Logging
import org.apache.spark.shuffle._
import org.apache.spark.shuffle.api.ShuffleExecutorComponents
import org.apache.spark.shuffle.api.metadata.{MapOutputMetadata, MapOutputMetadataFactory}
import org.apache.spark.util.collection.OpenHashSet

/**
@@ -194,6 +195,12 @@ private[spark] class SortShuffleManager(conf: SparkConf) extends ShuffleManager
override def stop(): Unit = {
shuffleBlockResolver.stop()
}

override val mapOutputMetadataFactory: MapOutputMetadataFactory = {
new MapOutputMetadataFactory {
override def create(): MapOutputMetadata = null
}
}
}
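SortShuffleManager opts out by returning a factory that creates null, which keeps the built-in path metadata-free. A shuffle implementation that does want to ship metadata would override the same val; a hypothetical sketch building on the RowCountMetadataFactory example above:

```scala
package org.apache.spark.shuffle.sort

import org.apache.spark.SparkConf
import org.apache.spark.shuffle.api.metadata.MapOutputMetadataFactory

// Hypothetical plugin: reuse the sort-based shuffle machinery, but advertise a real
// factory so deserialized MapStatuses get a fresh RowCountMetadata to read into.
private[spark] class RowCountingShuffleManager(conf: SparkConf) extends SortShuffleManager(conf) {
  override val mapOutputMetadataFactory: MapOutputMetadataFactory =
    new RowCountMetadataFactory()
}
```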

