
[SPARK-3386] Share and reuse SerializerInstances in shuffle paths #5606


Closed · wants to merge 3 commits

CoarseGrainedExecutorBackend.scala
@@ -30,6 +30,7 @@ import org.apache.spark.deploy.SparkHadoopUtil
import org.apache.spark.deploy.worker.WorkerWatcher
import org.apache.spark.scheduler.TaskDescription
import org.apache.spark.scheduler.cluster.CoarseGrainedClusterMessages._
import org.apache.spark.serializer.SerializerInstance
import org.apache.spark.util.{SignalLogger, Utils}

private[spark] class CoarseGrainedExecutorBackend(
@@ -47,6 +48,10 @@ private[spark] class CoarseGrainedExecutorBackend(
var executor: Executor = null
@volatile var driver: Option[RpcEndpointRef] = None

// If this CoarseGrainedExecutorBackend is changed to support multiple threads, then this may need
// to be changed so that we don't share the serializer instance across threads
private[this] val ser: SerializerInstance = env.closureSerializer.newInstance()
Contributor: Should we make a thread-local version of this? I'm slightly worried we might change this from ThreadSafeRpcEndpoint to a non-thread-safe version (i.e. multiple threads), and then forget to change this.

Contributor Author: Done. Even if we have to look up the thread-local, that's probably cheaper than creating a whole new serializer for each task.
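
For illustration only, here is a minimal, self-contained sketch of the thread-local caching pattern being discussed. The Serializer, SerializerInstance, and ThreadLocalSerializerSketch definitions below are stand-ins invented for this example (not Spark's actual classes); only the ThreadLocal structure mirrors the suggestion.

```scala
// Illustrative stubs so the snippet compiles on its own; the ThreadLocal caching
// structure is the only point being demonstrated.
trait SerializerInstance {
  def serialize[T](t: T): Array[Byte]
}

class Serializer {
  def newInstance(): SerializerInstance = new SerializerInstance {
    override def serialize[T](t: T): Array[Byte] = t.toString.getBytes("UTF-8")
  }
}

object ThreadLocalSerializerSketch {
  private val serializer = new Serializer

  // One SerializerInstance per thread: if the endpoint ever becomes multi-threaded,
  // each thread transparently gets its own instance, and the ThreadLocal lookup is
  // still far cheaper than calling newInstance() for every task.
  private val ser = new ThreadLocal[SerializerInstance] {
    override def initialValue(): SerializerInstance = serializer.newInstance()
  }

  def handleTask(payload: Any): Array[Byte] = ser.get().serialize(payload)

  def main(args: Array[String]): Unit = {
    val bytes = handleTask("task-0")
    println(s"serialized ${bytes.length} bytes on thread ${Thread.currentThread().getName}")
  }
}
```

The trade-off mentioned in the reply holds here: ThreadLocal.get() is a cheap lookup, while newInstance() can be arbitrarily expensive (e.g. constructing a fresh Kryo instance).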


override def onStart() {
import scala.concurrent.ExecutionContext.Implicits.global
logInfo("Connecting to driver: " + driverUrl)
@@ -83,7 +88,6 @@ private[spark] class CoarseGrainedExecutorBackend(
logError("Received LaunchTask command but executor was null")
System.exit(1)
} else {
val ser = env.closureSerializer.newInstance()
val taskDesc = ser.deserialize[TaskDescription](data.value)
logInfo("Got assigned task " + taskDesc.taskId)
executor.launchTask(this, taskId = taskDesc.taskId, attemptNumber = taskDesc.attemptNumber,

FileShuffleBlockManager.scala
@@ -113,11 +113,12 @@ class FileShuffleBlockManager(conf: SparkConf)
private var fileGroup: ShuffleFileGroup = null

val openStartTime = System.nanoTime
val serializerInstance = serializer.newInstance()
val writers: Array[BlockObjectWriter] = if (consolidateShuffleFiles) {
fileGroup = getUnusedFileGroup()
Array.tabulate[BlockObjectWriter](numBuckets) { bucketId =>
val blockId = ShuffleBlockId(shuffleId, mapId, bucketId)
blockManager.getDiskWriter(blockId, fileGroup(bucketId), serializer, bufferSize,
blockManager.getDiskWriter(blockId, fileGroup(bucketId), serializerInstance, bufferSize,
writeMetrics)
}
} else {
@@ -133,7 +134,8 @@ class FileShuffleBlockManager(conf: SparkConf)
logWarning(s"Failed to remove existing shuffle file $blockFile")
}
}
blockManager.getDiskWriter(blockId, blockFile, serializer, bufferSize, writeMetrics)
blockManager.getDiskWriter(blockId, blockFile, serializerInstance, bufferSize,
writeMetrics)

Contributor Author: Note that this line is called once for every bucket (reduce task), since it's enclosed in Array.tabulate[BlockObjectWriter](numBuckets) { bucketId => ... }.
}
}
// Creating the file to write to and creating a disk writer both involve interacting with
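
To make the "once for every bucket" comment above concrete, here is a small self-contained sketch (CountingSerializer and PerBucketInstanceSketch are hypothetical names, not Spark code) contrasting a call to newInstance() inside the Array.tabulate closure with a single instance created up front and shared across all buckets, which is what this hunk does.

```scala
// Hypothetical stub that only counts how many "serializer instances" get created.
class CountingSerializer {
  var instancesCreated = 0
  def newInstance(): String = {
    instancesCreated += 1
    s"instance-$instancesCreated"
  }
}

object PerBucketInstanceSketch {
  def main(args: Array[String]): Unit = {
    val numBuckets = 8

    // Before: the closure body runs once per bucket, so newInstance() runs numBuckets times.
    val before = new CountingSerializer
    val perBucket = Array.tabulate[String](numBuckets) { bucketId => before.newInstance() }
    println(s"per-bucket: ${before.instancesCreated} instances created")  // 8

    // After (as in this hunk): one instance created outside the closure and reused everywhere.
    val after = new CountingSerializer
    val shared = after.newInstance()
    val reused = Array.tabulate[String](numBuckets) { bucketId => shared }
    println(s"shared:     ${after.instancesCreated} instance created")    // 1
  }
}
```

With numBuckets reduce partitions, the first variant constructs numBuckets serializer instances per map task; the hoisted serializerInstance above constructs exactly one.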

BlockManager.scala
@@ -37,7 +37,7 @@ import org.apache.spark.network.netty.SparkTransportConf
import org.apache.spark.network.shuffle.ExternalShuffleClient
import org.apache.spark.network.shuffle.protocol.ExecutorShuffleInfo
import org.apache.spark.rpc.RpcEnv
import org.apache.spark.serializer.Serializer
import org.apache.spark.serializer.{SerializerInstance, Serializer}
import org.apache.spark.shuffle.ShuffleManager
import org.apache.spark.shuffle.hash.HashShuffleManager
import org.apache.spark.util._
@@ -646,13 +646,13 @@ private[spark] class BlockManager(
def getDiskWriter(
blockId: BlockId,
file: File,
serializer: Serializer,
serializerInstance: SerializerInstance,
bufferSize: Int,
writeMetrics: ShuffleWriteMetrics): BlockObjectWriter = {
val compressStream: OutputStream => OutputStream = wrapForCompression(blockId, _)
val syncWrites = conf.getBoolean("spark.shuffle.sync", false)
new DiskBlockObjectWriter(blockId, file, serializer, bufferSize, compressStream, syncWrites,
writeMetrics)
new DiskBlockObjectWriter(blockId, file, serializerInstance, bufferSize, compressStream,
syncWrites, writeMetrics)
}

/**

BlockObjectWriter.scala
@@ -21,7 +21,7 @@ import java.io.{BufferedOutputStream, FileOutputStream, File, OutputStream}
import java.nio.channels.FileChannel

import org.apache.spark.Logging
import org.apache.spark.serializer.{SerializationStream, Serializer}
import org.apache.spark.serializer.{SerializerInstance, SerializationStream}
import org.apache.spark.executor.ShuffleWriteMetrics
import org.apache.spark.util.Utils

@@ -71,7 +71,7 @@ private[spark] abstract class BlockObjectWriter(val blockId: BlockId) {
private[spark] class DiskBlockObjectWriter(
blockId: BlockId,
file: File,
serializer: Serializer,
serializerInstance: SerializerInstance,
bufferSize: Int,
compressStream: OutputStream => OutputStream,
syncWrites: Boolean,
@@ -134,7 +134,7 @@ private[spark] class DiskBlockObjectWriter(
ts = new TimeTrackingOutputStream(fos)
channel = fos.getChannel()
bs = compressStream(new BufferedOutputStream(ts, bufferSize))
objOut = serializer.newInstance().serializeStream(bs)
objOut = serializerInstance.serializeStream(bs)
initialized = true
this
}

ShuffleBlockFetcherIterator.scala
@@ -27,7 +27,7 @@ import org.apache.spark.{Logging, TaskContext}
import org.apache.spark.network.BlockTransferService
import org.apache.spark.network.shuffle.{BlockFetchingListener, ShuffleClient}
import org.apache.spark.network.buffer.ManagedBuffer
import org.apache.spark.serializer.Serializer
import org.apache.spark.serializer.{SerializerInstance, Serializer}
import org.apache.spark.util.{CompletionIterator, Utils}

/**
@@ -106,6 +106,8 @@ final class ShuffleBlockFetcherIterator(

private[this] val shuffleMetrics = context.taskMetrics.createShuffleReadMetricsForDependency()

private[this] val serializerInstance: SerializerInstance = serializer.newInstance()

/**
* Whether the iterator is still active. If isZombie is true, the callback interface will no
* longer place fetched blocks into [[results]].
@@ -299,7 +301,7 @@ final class ShuffleBlockFetcherIterator(
// the scheduler gets a FetchFailedException.
Try(buf.createInputStream()).map { is0 =>
val is = blockManager.wrapForCompression(blockId, is0)
val iter = serializer.newInstance().deserializeStream(is).asIterator
val iter = serializerInstance.deserializeStream(is).asIterator
CompletionIterator[Any, Iterator[Any]](iter, {
// Once the iterator is exhausted, release the buffer and set currentResult to null
// so we don't release it again in cleanup.

ExternalAppendOnlyMap.scala
@@ -151,8 +151,7 @@ class ExternalAppendOnlyMap[K, V, C](
override protected[this] def spill(collection: SizeTracker): Unit = {
val (blockId, file) = diskBlockManager.createTempLocalBlock()
curWriteMetrics = new ShuffleWriteMetrics()
var writer = blockManager.getDiskWriter(blockId, file, serializer, fileBufferSize,
curWriteMetrics)
var writer = blockManager.getDiskWriter(blockId, file, ser, fileBufferSize, curWriteMetrics)
var objectsWritten = 0

// List of batch sizes (bytes) in the order they are written to disk
@@ -179,8 +178,7 @@ class ExternalAppendOnlyMap[K, V, C](
if (objectsWritten == serializerBatchSize) {
flush()
curWriteMetrics = new ShuffleWriteMetrics()
writer = blockManager.getDiskWriter(blockId, file, serializer, fileBufferSize,
curWriteMetrics)
writer = blockManager.getDiskWriter(blockId, file, ser, fileBufferSize, curWriteMetrics)
}
}
if (objectsWritten > 0) {

ExternalSorter.scala
@@ -272,7 +272,8 @@ private[spark] class ExternalSorter[K, V, C](
// createTempShuffleBlock here; see SPARK-3426 for more context.
val (blockId, file) = diskBlockManager.createTempShuffleBlock()
curWriteMetrics = new ShuffleWriteMetrics()
var writer = blockManager.getDiskWriter(blockId, file, ser, fileBufferSize, curWriteMetrics)
var writer = blockManager.getDiskWriter(
blockId, file, serInstance, fileBufferSize, curWriteMetrics)
var objectsWritten = 0 // Objects written since the last flush

// List of batch sizes (bytes) in the order they are written to disk
@@ -308,7 +309,8 @@ private[spark] class ExternalSorter[K, V, C](
if (objectsWritten == serializerBatchSize) {
flush()
curWriteMetrics = new ShuffleWriteMetrics()
writer = blockManager.getDiskWriter(blockId, file, ser, fileBufferSize, curWriteMetrics)
writer = blockManager.getDiskWriter(
blockId, file, serInstance, fileBufferSize, curWriteMetrics)
}
}
if (objectsWritten > 0) {
@@ -358,7 +360,9 @@ private[spark] class ExternalSorter[K, V, C](
// spark.shuffle.compress instead of spark.shuffle.spill.compress, so we need to use
// createTempShuffleBlock here; see SPARK-3426 for more context.
val (blockId, file) = diskBlockManager.createTempShuffleBlock()
blockManager.getDiskWriter(blockId, file, ser, fileBufferSize, curWriteMetrics).open()
val writer = blockManager.getDiskWriter(blockId, file, serInstance, fileBufferSize,
curWriteMetrics)
writer.open()
}
// Creating the file to write to and creating a disk writer both involve interacting with
// the disk, and can take a long time in aggregate when we open many files, so should be
@@ -749,8 +753,8 @@ private[spark] class ExternalSorter[K, V, C](
// partition and just write everything directly.
for ((id, elements) <- this.partitionedIterator) {
if (elements.hasNext) {
val writer = blockManager.getDiskWriter(
blockId, outputFile, ser, fileBufferSize, context.taskMetrics.shuffleWriteMetrics.get)
val writer = blockManager.getDiskWriter(blockId, outputFile, serInstance, fileBufferSize,
context.taskMetrics.shuffleWriteMetrics.get)
for (elem <- elements) {
writer.write(elem)
}

BlockObjectWriterSuite.scala
@@ -30,7 +30,7 @@ class BlockObjectWriterSuite extends FunSuite {
val file = new File(Utils.createTempDir(), "somefile")
val writeMetrics = new ShuffleWriteMetrics()
val writer = new DiskBlockObjectWriter(new TestBlockId("0"), file,
new JavaSerializer(new SparkConf()), 1024, os => os, true, writeMetrics)
new JavaSerializer(new SparkConf()).newInstance(), 1024, os => os, true, writeMetrics)

writer.write(Long.box(20))
// Record metrics update on every write
@@ -52,7 +52,7 @@ class BlockObjectWriterSuite extends FunSuite {
val file = new File(Utils.createTempDir(), "somefile")
val writeMetrics = new ShuffleWriteMetrics()
val writer = new DiskBlockObjectWriter(new TestBlockId("0"), file,
new JavaSerializer(new SparkConf()), 1024, os => os, true, writeMetrics)
new JavaSerializer(new SparkConf()).newInstance(), 1024, os => os, true, writeMetrics)

writer.write(Long.box(20))
// Record metrics update on every write
@@ -75,7 +75,7 @@ class BlockObjectWriterSuite extends FunSuite {
val file = new File(Utils.createTempDir(), "somefile")
val writeMetrics = new ShuffleWriteMetrics()
val writer = new DiskBlockObjectWriter(new TestBlockId("0"), file,
new JavaSerializer(new SparkConf()), 1024, os => os, true, writeMetrics)
new JavaSerializer(new SparkConf()).newInstance(), 1024, os => os, true, writeMetrics)

writer.open()
writer.close()