
Commit 7420185

committed
Allow user to pass Serializer object instead of class name for shuffle.
This is more general than simply passing a string name and leaves more room for performance optimizations.
1 parent e19044c commit 7420185
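To illustrate the API change at a call site, here is a minimal, self-contained sketch (not part of this commit's diff; the local master, app name, and sample data are illustrative). Before, a shuffle serializer was selected by passing its class name as a string; after this commit, a configured Serializer instance is passed directly:

    import org.apache.spark.{HashPartitioner, SparkConf, SparkContext}
    import org.apache.spark.rdd.ShuffledRDD
    import org.apache.spark.serializer.KryoSerializer

    object ShuffleSerializerExample {
      def main(args: Array[String]) {
        val conf = new SparkConf().setMaster("local").setAppName("shuffle-serializer-example")
        val sc = new SparkContext(conf)
        val pairs = sc.parallelize(Seq((1, "a"), (2, "b"), (1, "c")))

        // Before this commit, the serializer was named by a class-name string:
        //   new ShuffledRDD[Int, String, (Int, String)](pairs, new HashPartitioner(2))
        //     .setSerializer("org.apache.spark.serializer.KryoSerializer")

        // After this commit, a Serializer instance is passed directly, so the
        // caller can construct and configure it once and reuse it.
        val shuffled = new ShuffledRDD[Int, String, (Int, String)](pairs, new HashPartitioner(2))
          .setSerializer(new KryoSerializer(conf))

        println(shuffled.collect().toSeq)
        sc.stop()
      }
    }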

File tree

18 files changed: +98 -162 lines changed


core/src/main/scala/org/apache/spark/Dependency.scala

Lines changed: 3 additions & 2 deletions
@@ -18,6 +18,7 @@
 package org.apache.spark
 
 import org.apache.spark.rdd.RDD
+import org.apache.spark.serializer.Serializer
 
 /**
  * Base class for dependencies.
@@ -43,12 +44,12 @@ abstract class NarrowDependency[T](rdd: RDD[T]) extends Dependency(rdd) {
  * Represents a dependency on the output of a shuffle stage.
  * @param rdd the parent RDD
  * @param partitioner partitioner used to partition the shuffle output
- * @param serializerClass class name of the serializer to use
+ * @param serializer [[Serializer]] to use.
  */
 class ShuffleDependency[K, V](
     @transient rdd: RDD[_ <: Product2[K, V]],
     val partitioner: Partitioner,
-    val serializerClass: String = null)
+    val serializer: Serializer = null)
   extends Dependency(rdd.asInstanceOf[RDD[Product2[K, V]]]) {
 
   val shuffleId: Int = rdd.context.newShuffleId()

core/src/main/scala/org/apache/spark/ShuffleFetcher.scala

Lines changed: 1 addition & 1 deletion
@@ -29,7 +29,7 @@ private[spark] abstract class ShuffleFetcher {
       shuffleId: Int,
       reduceId: Int,
       context: TaskContext,
-      serializer: Serializer = SparkEnv.get.serializerManager.default): Iterator[T]
+      serializer: Serializer = SparkEnv.get.serializer): Iterator[T]
 
   /** Stop the fetcher */
   def stop() {}

core/src/main/scala/org/apache/spark/SparkEnv.scala

Lines changed: 5 additions & 9 deletions
@@ -28,7 +28,7 @@ import org.apache.spark.broadcast.BroadcastManager
 import org.apache.spark.metrics.MetricsSystem
 import org.apache.spark.storage.{BlockManager, BlockManagerMaster, BlockManagerMasterActor}
 import org.apache.spark.network.ConnectionManager
-import org.apache.spark.serializer.{Serializer, SerializerManager}
+import org.apache.spark.serializer.Serializer
 import org.apache.spark.util.{AkkaUtils, Utils}
 
 /**
@@ -41,7 +41,6 @@ import org.apache.spark.util.{AkkaUtils, Utils}
 class SparkEnv private[spark] (
     val executorId: String,
     val actorSystem: ActorSystem,
-    val serializerManager: SerializerManager,
     val serializer: Serializer,
     val closureSerializer: Serializer,
     val cacheManager: CacheManager,
@@ -141,14 +140,12 @@ object SparkEnv extends Logging {
       val name = conf.get(propertyName, defaultClassName)
       Class.forName(name, true, classLoader).newInstance().asInstanceOf[T]
     }
-    val serializerManager = new SerializerManager
 
-    val serializer = serializerManager.setDefault(
-      conf.get("spark.serializer", "org.apache.spark.serializer.JavaSerializer"), conf)
+    val serializer = instantiateClass[Serializer](
+      "spark.serializer", "org.apache.spark.serializer.JavaSerializer")
 
-    val closureSerializer = serializerManager.get(
-      conf.get("spark.closure.serializer", "org.apache.spark.serializer.JavaSerializer"),
-      conf)
+    val closureSerializer = instantiateClass[Serializer](
+      "spark.closure.serializer", "org.apache.spark.serializer.JavaSerializer")
 
     def registerOrLookup(name: String, newActor: => Actor): ActorRef = {
       if (isDriver) {
@@ -220,7 +217,6 @@ object SparkEnv extends Logging {
     new SparkEnv(
       executorId,
       actorSystem,
-      serializerManager,
       serializer,
       closureSerializer,
      cacheManager,
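For context on the hunks above: the default and closure serializers are now instantiated reflectively from configuration rather than looked up through a SerializerManager. A usage sketch of selecting Kryo as the default serializer (the master and app-name values are illustrative):

    import org.apache.spark.{SparkConf, SparkContext}

    // SparkEnv reflectively instantiates the class named by "spark.serializer"
    // (see instantiateClass above); the value is a fully qualified class name.
    val conf = new SparkConf()
      .setMaster("local")
      .setAppName("default-serializer-example")
      .set("spark.serializer", "org.apache.spark.serializer.KryoSerializer")
    val sc = new SparkContext(conf)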

core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala

Lines changed: 9 additions & 9 deletions
@@ -24,6 +24,7 @@ import scala.collection.mutable.ArrayBuffer
 import org.apache.spark.{InterruptibleIterator, Partition, Partitioner, SparkEnv, TaskContext}
 import org.apache.spark.{Dependency, OneToOneDependency, ShuffleDependency}
 import org.apache.spark.util.collection.{ExternalAppendOnlyMap, AppendOnlyMap}
+import org.apache.spark.serializer.Serializer
 
 private[spark] sealed trait CoGroupSplitDep extends Serializable
 
@@ -66,10 +67,10 @@ class CoGroupedRDD[K](@transient var rdds: Seq[RDD[_ <: Product2[K, _]]], part:
   private type CoGroupValue = (Any, Int) // Int is dependency number
   private type CoGroupCombiner = Seq[CoGroup]
 
-  private var serializerClass: String = null
+  private var serializer: Serializer = null
 
-  def setSerializer(cls: String): CoGroupedRDD[K] = {
-    serializerClass = cls
+  def setSerializer(serializer: Serializer): CoGroupedRDD[K] = {
+    this.serializer = serializer
     this
   }
 
@@ -80,7 +81,7 @@ class CoGroupedRDD[K](@transient var rdds: Seq[RDD[_ <: Product2[K, _]]], part:
         new OneToOneDependency(rdd)
       } else {
         logDebug("Adding shuffle dependency with " + rdd)
-        new ShuffleDependency[Any, Any](rdd, part, serializerClass)
+        new ShuffleDependency[Any, Any](rdd, part, serializer)
       }
     }
   }
@@ -113,18 +114,17 @@ class CoGroupedRDD[K](@transient var rdds: Seq[RDD[_ <: Product2[K, _]]], part:
     // A list of (rdd iterator, dependency number) pairs
     val rddIterators = new ArrayBuffer[(Iterator[Product2[K, Any]], Int)]
     for ((dep, depNum) <- split.deps.zipWithIndex) dep match {
-      case NarrowCoGroupSplitDep(rdd, _, itsSplit) => {
+      case NarrowCoGroupSplitDep(rdd, _, itsSplit) =>
         // Read them from the parent
         val it = rdd.iterator(itsSplit, context).asInstanceOf[Iterator[Product2[K, Any]]]
         rddIterators += ((it, depNum))
-      }
-      case ShuffleCoGroupSplitDep(shuffleId) => {
+
+      case ShuffleCoGroupSplitDep(shuffleId) =>
         // Read map outputs of shuffle
         val fetcher = SparkEnv.get.shuffleFetcher
-        val ser = SparkEnv.get.serializerManager.get(serializerClass, sparkConf)
+        val ser = Serializer.getSerializer(serializer)
         val it = fetcher.fetch[Product2[K, Any]](shuffleId, split.index, context, ser)
         rddIterators += ((it, depNum))
-      }
     }
 
     if (!externalSorting) {
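The hunks above call Serializer.getSerializer(serializer), a helper added elsewhere in this commit (in serializer/Serializer.scala, one of the 18 changed files but not shown in this view). A hedged sketch of the behavior the call sites rely on — fall back to the SparkEnv default when no serializer was set — using an illustrative object name rather than the real companion object:

    import org.apache.spark.SparkEnv
    import org.apache.spark.serializer.Serializer

    // Illustrative stand-in, not the actual code from this commit.
    object SerializerResolution {
      /** Return the given serializer, or the SparkEnv default when it is null. */
      def getSerializer(serializer: Serializer): Serializer =
        if (serializer == null) SparkEnv.get.serializer else serializer
    }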

core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala

Lines changed: 4 additions & 3 deletions
@@ -44,6 +44,7 @@ import org.apache.spark._
 import org.apache.spark.Partitioner.defaultPartitioner
 import org.apache.spark.SparkContext._
 import org.apache.spark.partial.{BoundedDouble, PartialResult}
+import org.apache.spark.serializer.Serializer
 import org.apache.spark.util.SerializableHyperLogLog
 
 /**
@@ -73,7 +74,7 @@ class PairRDDFunctions[K: ClassTag, V: ClassTag](self: RDD[(K, V)])
       mergeCombiners: (C, C) => C,
       partitioner: Partitioner,
       mapSideCombine: Boolean = true,
-      serializerClass: String = null): RDD[(K, C)] = {
+      serializer: Serializer = null): RDD[(K, C)] = {
     require(mergeCombiners != null, "mergeCombiners must be defined") // required as of Spark 0.9.0
     if (getKeyClass().isArray) {
       if (mapSideCombine) {
@@ -93,13 +94,13 @@ class PairRDDFunctions[K: ClassTag, V: ClassTag](self: RDD[(K, V)])
         aggregator.combineValuesByKey(iter, context)
       }, preservesPartitioning = true)
       val partitioned = new ShuffledRDD[K, C, (K, C)](combined, partitioner)
-        .setSerializer(serializerClass)
+        .setSerializer(serializer)
       partitioned.mapPartitionsWithContext((context, iter) => {
         new InterruptibleIterator(context, aggregator.combineCombinersByKey(iter, context))
       }, preservesPartitioning = true)
     } else {
       // Don't apply map-side combiner.
-      val values = new ShuffledRDD[K, V, (K, V)](self, partitioner).setSerializer(serializerClass)
+      val values = new ShuffledRDD[K, V, (K, V)](self, partitioner).setSerializer(serializer)
       values.mapPartitionsWithContext((context, iter) => {
         new InterruptibleIterator(context, aggregator.combineValuesByKey(iter, context))
       }, preservesPartitioning = true)
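A usage sketch of the updated combineByKey signature above, where the final argument is now a Serializer instance rather than a class-name string (the local setup and sample data are illustrative):

    import org.apache.spark.{HashPartitioner, SparkConf, SparkContext}
    import org.apache.spark.SparkContext._   // brings PairRDDFunctions into scope
    import org.apache.spark.serializer.KryoSerializer

    val conf = new SparkConf().setMaster("local").setAppName("combine-by-key-example")
    val sc = new SparkContext(conf)
    val pairs = sc.parallelize(Seq(("a", 1), ("b", 2), ("a", 3)))

    // Sum values per key, shuffling with a caller-supplied Kryo serializer.
    val sums = pairs.combineByKey[Int](
      (v: Int) => v,                      // createCombiner
      (c: Int, v: Int) => c + v,          // mergeValue
      (c1: Int, c2: Int) => c1 + c2,      // mergeCombiners
      new HashPartitioner(2),
      mapSideCombine = true,
      serializer = new KryoSerializer(conf))

    println(sums.collect().toSeq)   // e.g. (a,4), (b,2) in some order
    sc.stop()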

core/src/main/scala/org/apache/spark/rdd/ShuffledRDD.scala

Lines changed: 7 additions & 6 deletions
@@ -20,6 +20,7 @@ package org.apache.spark.rdd
 import scala.reflect.ClassTag
 
 import org.apache.spark.{Dependency, Partition, Partitioner, ShuffleDependency, SparkEnv, TaskContext}
+import org.apache.spark.serializer.Serializer
 
 private[spark] class ShuffledRDDPartition(val idx: Int) extends Partition {
   override val index = idx
@@ -38,15 +39,15 @@ class ShuffledRDD[K, V, P <: Product2[K, V] : ClassTag](
     part: Partitioner)
   extends RDD[P](prev.context, Nil) {
 
-  private var serializerClass: String = null
+  private var serializer: Serializer = null
 
-  def setSerializer(cls: String): ShuffledRDD[K, V, P] = {
-    serializerClass = cls
+  def setSerializer(serializer: Serializer): ShuffledRDD[K, V, P] = {
+    this.serializer = serializer
     this
   }
 
   override def getDependencies: Seq[Dependency[_]] = {
-    List(new ShuffleDependency(prev, part, serializerClass))
+    List(new ShuffleDependency(prev, part, serializer))
   }
 
   override val partitioner = Some(part)
@@ -57,8 +58,8 @@ class ShuffledRDD[K, V, P <: Product2[K, V] : ClassTag](
 
   override def compute(split: Partition, context: TaskContext): Iterator[P] = {
     val shuffledId = dependencies.head.asInstanceOf[ShuffleDependency[K, V]].shuffleId
-    SparkEnv.get.shuffleFetcher.fetch[P](shuffledId, split.index, context,
-      SparkEnv.get.serializerManager.get(serializerClass, SparkEnv.get.conf))
+    val ser = Serializer.getSerializer(serializer)
+    SparkEnv.get.shuffleFetcher.fetch[P](shuffledId, split.index, context, ser)
   }
 
   override def clearDependencies() {

core/src/main/scala/org/apache/spark/rdd/SubtractedRDD.scala

Lines changed: 10 additions & 10 deletions
@@ -30,6 +30,7 @@ import org.apache.spark.Partitioner
 import org.apache.spark.ShuffleDependency
 import org.apache.spark.SparkEnv
 import org.apache.spark.TaskContext
+import org.apache.spark.serializer.Serializer
 
 /**
  * An optimized version of cogroup for set difference/subtraction.
@@ -53,10 +54,10 @@ private[spark] class SubtractedRDD[K: ClassTag, V: ClassTag, W: ClassTag](
     part: Partitioner)
   extends RDD[(K, V)](rdd1.context, Nil) {
 
-  private var serializerClass: String = null
+  private var serializer: Serializer = null
 
-  def setSerializer(cls: String): SubtractedRDD[K, V, W] = {
-    serializerClass = cls
+  def setSerializer(serializer: Serializer): SubtractedRDD[K, V, W] = {
+    this.serializer = serializer
     this
   }
 
@@ -67,7 +68,7 @@ private[spark] class SubtractedRDD[K: ClassTag, V: ClassTag, W: ClassTag](
         new OneToOneDependency(rdd)
       } else {
         logDebug("Adding shuffle dependency with " + rdd)
-        new ShuffleDependency(rdd, part, serializerClass)
+        new ShuffleDependency(rdd, part, serializer)
       }
     }
   }
@@ -92,7 +93,7 @@ private[spark] class SubtractedRDD[K: ClassTag, V: ClassTag, W: ClassTag](
 
   override def compute(p: Partition, context: TaskContext): Iterator[(K, V)] = {
     val partition = p.asInstanceOf[CoGroupPartition]
-    val serializer = SparkEnv.get.serializerManager.get(serializerClass, SparkEnv.get.conf)
+    val ser = Serializer.getSerializer(serializer)
     val map = new JHashMap[K, ArrayBuffer[V]]
     def getSeq(k: K): ArrayBuffer[V] = {
       val seq = map.get(k)
@@ -105,14 +106,13 @@ private[spark] class SubtractedRDD[K: ClassTag, V: ClassTag, W: ClassTag](
       }
     }
     def integrate(dep: CoGroupSplitDep, op: Product2[K, V] => Unit) = dep match {
-      case NarrowCoGroupSplitDep(rdd, _, itsSplit) => {
+      case NarrowCoGroupSplitDep(rdd, _, itsSplit) =>
        rdd.iterator(itsSplit, context).asInstanceOf[Iterator[Product2[K, V]]].foreach(op)
-      }
-      case ShuffleCoGroupSplitDep(shuffleId) => {
+
+      case ShuffleCoGroupSplitDep(shuffleId) =>
        val iter = SparkEnv.get.shuffleFetcher.fetch[Product2[K, V]](shuffleId, partition.index,
-          context, serializer)
+          context, ser)
        iter.foreach(op)
-      }
     }
     // the first dep is rdd1; add all values to the map
     integrate(partition.deps(0), t => getSeq(t._1) += t._2)
core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala

Lines changed: 2 additions & 1 deletion
@@ -26,6 +26,7 @@ import org.apache.spark._
 import org.apache.spark.executor.ShuffleWriteMetrics
 import org.apache.spark.rdd.RDD
 import org.apache.spark.rdd.RDDCheckpointData
+import org.apache.spark.serializer.Serializer
 import org.apache.spark.storage._
 import org.apache.spark.util.{MetadataCleaner, MetadataCleanerType, TimeStampedHashMap}
 
@@ -153,7 +154,7 @@ private[spark] class ShuffleMapTask(
 
     try {
       // Obtain all the block writers for shuffle blocks.
-      val ser = SparkEnv.get.serializerManager.get(dep.serializerClass, SparkEnv.get.conf)
+      val ser = Serializer.getSerializer(dep.serializer)
      shuffle = shuffleBlockManager.forMapTask(dep.shuffleId, partitionId, numOutputSplits, ser)
 
       // Write the map output to its associated buckets.

core/src/main/scala/org/apache/spark/serializer/JavaSerializer.scala

Lines changed: 1 addition & 1 deletion
@@ -97,6 +97,6 @@ private[spark] class JavaSerializerInstance(conf: SparkConf) extends SerializerI
 /**
  * A Spark serializer that uses Java's built-in serialization.
  */
-class JavaSerializer(conf: SparkConf) extends Serializer {
+class JavaSerializer(conf: SparkConf) extends Serializer with Serializable {
   def newInstance(): SerializerInstance = new JavaSerializerInstance(conf)
 }

core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala

Lines changed: 10 additions & 6 deletions
@@ -34,10 +34,14 @@ import org.apache.spark.storage.{GetBlock, GotBlock, PutBlock}
 /**
  * A Spark serializer that uses the [[https://code.google.com/p/kryo/ Kryo serialization library]].
  */
-class KryoSerializer(conf: SparkConf) extends org.apache.spark.serializer.Serializer with Logging {
-  private val bufferSize = {
-    conf.getInt("spark.kryoserializer.buffer.mb", 2) * 1024 * 1024
-  }
+class KryoSerializer(conf: SparkConf)
+  extends org.apache.spark.serializer.Serializer
+  with Logging
+  with Serializable {
+
+  private val bufferSize = conf.getInt("spark.kryoserializer.buffer.mb", 2) * 1024 * 1024
+  private val referenceTracking = conf.getBoolean("spark.kryo.referenceTracking", true)
+  private val registrator = conf.getOption("spark.kryo.registrator")
 
   def newKryoOutput() = new KryoOutput(bufferSize)
 
@@ -48,7 +52,7 @@ class KryoSerializer(conf: SparkConf) extends org.apache.spark.serializer.Serial
 
     // Allow disabling Kryo reference tracking if user knows their object graphs don't have loops.
     // Do this before we invoke the user registrator so the user registrator can override this.
-    kryo.setReferences(conf.getBoolean("spark.kryo.referenceTracking", true))
+    kryo.setReferences(referenceTracking)
 
     for (cls <- KryoSerializer.toRegister) kryo.register(cls)
 
@@ -58,7 +62,7 @@ class KryoSerializer(conf: SparkConf) extends org.apache.spark.serializer.Serial
 
     // Allow the user to register their own classes by setting spark.kryo.registrator
     try {
-      for (regCls <- conf.getOption("spark.kryo.registrator")) {
+      for (regCls <- registrator) {
        logDebug("Running user registrator: " + regCls)
        val reg = Class.forName(regCls, true, classLoader).newInstance()
          .asInstanceOf[KryoRegistrator]