
[SPARK-3796] Create external service which can serve shuffle files #3001

Closed
wants to merge 12 commits
5 changes: 5 additions & 0 deletions core/pom.xml
@@ -49,6 +49,11 @@
<artifactId>spark-network-common_2.10</artifactId>
<version>${project.version}</version>
</dependency>
<dependency>
<groupId>org.apache.spark</groupId>
<artifactId>spark-network-shuffle_2.10</artifactId>
<version>${project.version}</version>
</dependency>
<dependency>
<groupId>net.java.dev.jets3t</groupId>
<artifactId>jets3t</artifactId>
4 changes: 3 additions & 1 deletion core/src/main/scala/org/apache/spark/MapOutputTracker.scala
@@ -178,6 +178,7 @@ private[spark] abstract class MapOutputTracker(conf: SparkConf) extends Logging
return MapOutputTracker.convertMapStatuses(shuffleId, reduceId, fetchedStatuses)
}
} else {
logError("Missing all output locations for shuffle " + shuffleId)
throw new MetadataFetchFailedException(
shuffleId, reduceId, "Missing all output locations for shuffle " + shuffleId)
}
@@ -348,7 +349,7 @@ private[spark] class MapOutputTrackerWorker(conf: SparkConf) extends MapOutputTr
new ConcurrentHashMap[Int, Array[MapStatus]]
}

private[spark] object MapOutputTracker {
private[spark] object MapOutputTracker extends Logging {

// Serialize an array of map output locations into an efficient byte format so that we can send
// it to reduce tasks. We do this by compressing the serialized bytes using GZIP. They will
@@ -381,6 +382,7 @@ private[spark] object MapOutputTracker {
statuses.map {
status =>
if (status == null) {
logError("Missing an output location for shuffle " + shuffleId)
throw new MetadataFetchFailedException(
shuffleId, reduceId, "Missing an output location for shuffle " + shuffleId)
} else {
2 changes: 1 addition & 1 deletion core/src/main/scala/org/apache/spark/SparkEnv.scala
@@ -32,7 +32,7 @@ import org.apache.spark.api.python.PythonWorkerFactory
import org.apache.spark.broadcast.BroadcastManager
import org.apache.spark.metrics.MetricsSystem
import org.apache.spark.network.BlockTransferService
import org.apache.spark.network.netty.{NettyBlockTransferService}
import org.apache.spark.network.netty.NettyBlockTransferService
import org.apache.spark.network.nio.NioBlockTransferService
import org.apache.spark.scheduler.LiveListenerBus
import org.apache.spark.serializer.Serializer
@@ -186,11 +186,11 @@ private[spark] class Worker(
private def retryConnectToMaster() {
Utils.tryOrExit {
connectionAttemptCount += 1
logInfo(s"Attempting to connect to master (attempt # $connectionAttemptCount")
if (registered) {
registrationRetryTimer.foreach(_.cancel())
registrationRetryTimer = None
} else if (connectionAttemptCount <= TOTAL_REGISTRATION_RETRIES) {
logInfo(s"Retrying connection to master (attempt # $connectionAttemptCount)")
tryRegisterAllMasters()
if (connectionAttemptCount == INITIAL_REGISTRATION_RETRIES) {
registrationRetryTimer.foreach(_.cancel())
@@ -78,7 +78,7 @@ private[spark] class Executor(
val executorSource = new ExecutorSource(this, executorId)

// Initialize Spark environment (using system properties read above)
conf.set("spark.executor.id", "executor." + executorId)
conf.set("spark.executor.id", executorId)
Contributor: what is this change about?

Contributor Author: This was introduced recently, and I was planning on using it, but ended up not. Still, I was inclined to keep the seemingly more sensible semantics of "spark.executor.id" being the executorId rather than being prefixed. It is currently only used by the MetricsSystem.

Contributor: Yeah, that makes sense. This was introduced in a patch that was merged not long ago (middle of the 1.2 window), so it's OK to change it.
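
A minimal sketch of the config semantics discussed above, using purely illustrative values: with this patch the property carries the bare executor ID instead of a prefixed string.

import org.apache.spark.SparkConf

// Illustrative only: the executor stores its bare ID under "spark.executor.id".
val conf = new SparkConf()
val executorId = "4"                            // hypothetical executor ID
conf.set("spark.executor.id", executorId)
assert(conf.get("spark.executor.id") == "4")    // previously the stored value would have been "executor.4"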

private val env = {
if (!isLocal) {
val port = conf.getInt("spark.executor.port", 0)
@@ -20,16 +20,16 @@ package org.apache.spark.network
import java.io.Closeable
import java.nio.ByteBuffer

import scala.concurrent.{Await, Future}
import scala.concurrent.{Promise, Await, Future}
import scala.concurrent.duration.Duration

import org.apache.spark.Logging
import org.apache.spark.network.buffer.{NioManagedBuffer, ManagedBuffer}
import org.apache.spark.storage.{BlockId, StorageLevel}
import org.apache.spark.util.Utils
import org.apache.spark.network.shuffle.{ShuffleClient, BlockFetchingListener}
import org.apache.spark.storage.{BlockManagerId, BlockId, StorageLevel}

private[spark]
abstract class BlockTransferService extends Closeable with Logging {
abstract class BlockTransferService extends ShuffleClient with Closeable with Logging {

/**
* Initialize the transfer service by giving it the BlockDataManager that can be used to fetch
@@ -60,10 +60,11 @@ abstract class BlockTransferService extends Closeable with Logging {
* return a future so the underlying implementation can invoke onBlockFetchSuccess as soon as
* the data of a block is fetched, rather than waiting for all blocks to be fetched.
*/
def fetchBlocks(
hostName: String,
override def fetchBlocks(
host: String,
port: Int,
blockIds: Seq[String],
execId: String,
Contributor: execId should be part of the connection establishment / registration and not part of fetchBlocks.

Contributor Author: The tricky part here is that "execId" is actually part of the request. I am fetching Executor 6's blocks, while I am myself Executor 4. So there is no API that is exposed at a lower layer to transfer the execId.

Contributor: I see - does each executor have its own path for shuffle files?

Contributor Author: Yes, each executor registers its ExecutorShuffleInfo, which includes its own localDirs (created by the Executor on initialization).

blockIds: Array[String],
listener: BlockFetchingListener): Unit

/**
@@ -81,43 +82,23 @@ abstract class BlockTransferService extends Closeable with Logging {
*
* It is also only available after [[init]] is invoked.
*/
def fetchBlockSync(hostName: String, port: Int, blockId: String): ManagedBuffer = {
def fetchBlockSync(host: String, port: Int, execId: String, blockId: String): ManagedBuffer = {
// A monitor for the thread to wait on.
val lock = new Object
@volatile var result: Either[ManagedBuffer, Throwable] = null
fetchBlocks(hostName, port, Seq(blockId), new BlockFetchingListener {
override def onBlockFetchFailure(blockId: String, exception: Throwable): Unit = {
lock.synchronized {
result = Right(exception)
lock.notify()
val result = Promise[ManagedBuffer]()
fetchBlocks(host, port, execId, Array(blockId),
new BlockFetchingListener {
override def onBlockFetchFailure(blockId: String, exception: Throwable): Unit = {
result.failure(exception)
}
}
override def onBlockFetchSuccess(blockId: String, data: ManagedBuffer): Unit = {
lock.synchronized {
override def onBlockFetchSuccess(blockId: String, data: ManagedBuffer): Unit = {
val ret = ByteBuffer.allocate(data.size.toInt)
ret.put(data.nioByteBuffer())
ret.flip()
result = Left(new NioManagedBuffer(ret))
lock.notify()
result.success(new NioManagedBuffer(ret))
}
}
})
})

// Sleep until result is no longer null
lock.synchronized {
while (result == null) {
try {
lock.wait()
} catch {
case e: InterruptedException =>
}
}
}

result match {
case Left(data) => data
case Right(e) => throw e
}
Await.result(result.future, Duration.Inf)
}

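The execId discussion above concerns the new fetchBlocks signature in this file: the caller passes the ID of the executor that owns the blocks, because the requester may be a different executor. Below is a brief usage sketch against the signatures shown above; the host, port, executor ID, block ID, and the blockTransferService instance are illustrative assumptions, not values taken from the PR.

import org.apache.spark.network.buffer.ManagedBuffer
import org.apache.spark.network.shuffle.BlockFetchingListener

// Hypothetical scenario: a task running as executor "4" fetches a shuffle block
// written by executor "6" from the service holding that executor's files.
blockTransferService.fetchBlocks(
  host = "shuffle-host.example.com",
  port = 7337,
  execId = "6",                         // owner of the requested blocks, not the caller
  blockIds = Array("shuffle_0_5_2"),
  listener = new BlockFetchingListener {
    override def onBlockFetchSuccess(blockId: String, data: ManagedBuffer): Unit =
      println(s"fetched $blockId (${data.size} bytes)")
    override def onBlockFetchFailure(blockId: String, exception: Throwable): Unit =
      exception.printStackTrace()
  })

// For a single block, the synchronous helper above does the same work and blocks
// until the buffer arrives or the fetch fails.
val buffer: ManagedBuffer =
  blockTransferService.fetchBlockSync("shuffle-host.example.com", 7337, "6", "shuffle_0_5_2")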

This file was deleted.

@@ -19,39 +19,41 @@ package org.apache.spark.network.netty

import java.nio.ByteBuffer

import scala.collection.JavaConversions._

import org.apache.spark.Logging
import org.apache.spark.network.BlockDataManager
import org.apache.spark.network.buffer.{ManagedBuffer, NioManagedBuffer}
import org.apache.spark.network.client.{RpcResponseCallback, TransportClient}
import org.apache.spark.network.server.{OneForOneStreamManager, RpcHandler, StreamManager}
import org.apache.spark.network.shuffle.ShuffleStreamHandle
import org.apache.spark.serializer.Serializer
import org.apache.spark.network.buffer.{NioManagedBuffer, ManagedBuffer}
import org.apache.spark.network.client.{TransportClient, RpcResponseCallback}
import org.apache.spark.network.server.{DefaultStreamManager, RpcHandler}
import org.apache.spark.storage.{StorageLevel, BlockId}

import scala.collection.JavaConversions._
import org.apache.spark.storage.{BlockId, StorageLevel}

object NettyMessages {

/** Request to read a set of blocks. Returns [[ShuffleStreamHandle]] to identify the stream. */
case class OpenBlocks(blockIds: Seq[BlockId])

/** Request to upload a block with a certain StorageLevel. Returns nothing (empty byte array). */
case class UploadBlock(blockId: BlockId, blockData: Array[Byte], level: StorageLevel)

/** Identifier for a fixed number of chunks to read from a stream created by [[OpenBlocks]]. */
case class ShuffleStreamHandle(streamId: Long, numChunks: Int)
}

/**
* Serves requests to open blocks by simply registering one chunk per block requested.
* Handles opening and uploading arbitrary BlockManager blocks.
*
* Opened blocks are registered with the "one-for-one" strategy, meaning each Transport-layer Chunk
* is equivalent to one Spark-level shuffle block.
*/
class NettyBlockRpcServer(
serializer: Serializer,
streamManager: DefaultStreamManager,
blockManager: BlockDataManager)
extends RpcHandler with Logging {

import NettyMessages._

private val streamManager = new OneForOneStreamManager()

override def receive(
client: TransportClient,
messageBytes: Array[Byte],
@@ -73,4 +75,6 @@ class NettyBlockRpcServer(
responseContext.onSuccess(new Array[Byte](0))
}
}

override def getStreamManager(): StreamManager = streamManager
}
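
The class comment above describes the "one-for-one" strategy: opening N blocks yields a stream of N transport-layer chunks, with chunk i backed by block i. The following is a conceptual sketch of that mapping in terms of the blockManager and streamManager fields of the class above; registerStream and its exact signature are assumptions for illustration, and the real receive() handler additionally deserializes the request and responds through the RPC callback.

// Sketch only: answer an OpenBlocks request by looking up one ManagedBuffer per
// requested block and registering them all as a single stream, so that transport
// chunk i serves Spark block blockIds(i).
def openBlocks(blockIds: Seq[BlockId]): (Long, Int) = {
  val buffers = blockIds.map(blockManager.getBlockData)         // one buffer per block
  val streamId = streamManager.registerStream(buffers.iterator) // assumed registration helper
  (streamId, buffers.size)                                      // stream ID and chunk count returned to the client
}
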
@@ -17,15 +17,15 @@

package org.apache.spark.network.netty

import scala.concurrent.{Promise, Future}
import scala.concurrent.{Future, Promise}

import org.apache.spark.SparkConf
import org.apache.spark.network._
import org.apache.spark.network.buffer.ManagedBuffer
import org.apache.spark.network.client.{RpcResponseCallback, TransportClient, TransportClientFactory}
import org.apache.spark.network.netty.NettyMessages.UploadBlock
import org.apache.spark.network.client.{RpcResponseCallback, TransportClientFactory}
import org.apache.spark.network.netty.NettyMessages.{OpenBlocks, UploadBlock}
import org.apache.spark.network.server._
import org.apache.spark.network.util.{ConfigProvider, TransportConf}
import org.apache.spark.network.shuffle.{BlockFetchingListener, OneForOneBlockFetcher}
import org.apache.spark.serializer.JavaSerializer
import org.apache.spark.storage.{BlockId, StorageLevel}
import org.apache.spark.util.Utils
@@ -37,30 +37,29 @@ class NettyBlockTransferService(conf: SparkConf) extends BlockTransferService {
// TODO: Don't use Java serialization, use a more cross-version compatible serialization format.
val serializer = new JavaSerializer(conf)

// Create a TransportConfig using SparkConf.
private[this] val transportConf = new TransportConf(
new ConfigProvider { override def get(name: String) = conf.get(name) })

private[this] var transportContext: TransportContext = _
private[this] var server: TransportServer = _
private[this] var clientFactory: TransportClientFactory = _

override def init(blockDataManager: BlockDataManager): Unit = {
val streamManager = new DefaultStreamManager
val rpcHandler = new NettyBlockRpcServer(serializer, streamManager, blockDataManager)
transportContext = new TransportContext(transportConf, streamManager, rpcHandler)
val rpcHandler = new NettyBlockRpcServer(serializer, blockDataManager)
transportContext = new TransportContext(SparkTransportConf.fromSparkConf(conf), rpcHandler)
clientFactory = transportContext.createClientFactory()
server = transportContext.createServer()
logInfo("Server created on " + server.getPort)
}

override def fetchBlocks(
hostname: String,
host: String,
port: Int,
blockIds: Seq[String],
execId: String,
Contributor: see above - best to leave execId out of this.

blockIds: Array[String],
listener: BlockFetchingListener): Unit = {
logTrace(s"Fetch blocks from $host:$port (executor id $execId)")
try {
val client = clientFactory.createClient(hostname, port)
new NettyBlockFetcher(serializer, client, blockIds, listener).start()
val client = clientFactory.createClient(host, port)
new OneForOneBlockFetcher(client, blockIds.toArray, listener)
.start(OpenBlocks(blockIds.map(BlockId.apply)))
} catch {
case e: Exception =>
logError("Exception while beginning fetchBlocks", e)