Commit 9da4b6b

JoshRosen authored and pwendell committed
[SPARK-7873] Allow KryoSerializerInstance to create multiple streams at the same time
This is a somewhat obscure bug, but I think that it will seriously impact KryoSerializer users who use custom registrators which disable auto-reset. When auto-reset is disabled, this breaks things in some of our shuffle paths, which actually end up creating multiple OutputStreams from the same shared SerializerInstance (which is unsafe).

This was introduced by a patch (SPARK-3386) which enables serializer re-use in some of the shuffle paths, since constructing new serializer instances is actually pretty costly for KryoSerializer. We had already fixed another corner-case (SPARK-7766) bug related to this, but missed this one.

I think that the root problem here is that KryoSerializerInstance can be used in a way which is unsafe even within a single thread, e.g. by creating multiple open OutputStreams from the same instance or by interleaving deserialize and deserializeStream calls. I considered a smaller patch which adds assertions to guard against this type of "misuse" but abandoned that approach after I realized how convoluted the Scaladoc became.

This patch fixes this bug by making it legal to create multiple streams from the same KryoSerializerInstance. Internally, KryoSerializerInstance now implements a `borrowKryo()` / `releaseKryo()` API that's backed by a "pool" of capacity 1. Each call to a KryoSerializerInstance method will borrow the Kryo, do its work, then release the Kryo back to the pool. If the pool is empty and we need an instance, it will allocate a new Kryo on-demand. This makes it safe for multiple OutputStreams to be opened from the same serializer. If we try to release a Kryo back to the pool but the pool already contains a Kryo, then we'll just discard the new Kryo.

I don't think there's a clear benefit to having a larger pool since our usages tend to fall into two cases: a) where we only create a single OutputStream and b) where we create a huge number of OutputStreams with the same lifecycle, then destroy the KryoSerializerInstance (this is what's happening in the bypassMergeSort code path that my regression test hits).

Author: Josh Rosen <joshrosen@databricks.com>

Closes #6415 from JoshRosen/SPARK-7873 and squashes the following commits:

00b402e [Josh Rosen] Initialize eagerly to fix a failing test
ba55d20 [Josh Rosen] Add explanatory comments
3f1da96 [Josh Rosen] Guard against duplicate close()
ab457ca [Josh Rosen] Sketch a loan/release based solution.
9816e8f [Josh Rosen] Add a failing test showing how deserialize() and deserializeStream() can interfere.
7350886 [Josh Rosen] Add failing regression test for SPARK-7873

(cherry picked from commit 852f4de)
Signed-off-by: Patrick Wendell <patrick@databricks.com>
1 parent bd9173c commit 9da4b6b
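
The borrow/release scheme described in the commit message can be pictured with a minimal, self-contained sketch. It is not the patch itself: `Codec` and `SingleSlotPool` are hypothetical names standing in for `Kryo` and the caching logic inside `KryoSerializerInstance`, and the factory closure stands in for `ks.newKryo()`. Like the class it models, the sketch assumes single-threaded use.

```scala
// Hypothetical stand-in for an expensive, stateful object such as a Kryo instance.
final class Codec(val id: Int) {
  var resets = 0
  def reset(): Unit = resets += 1
}

// A "pool" of capacity one. Deliberately not synchronized: single-threaded use is assumed.
final class SingleSlotPool(factory: () => Codec) {
  // Created eagerly, so the first borrow() is served from the cache.
  private[this] var cached: Codec = factory()

  // Hands out the cached instance if there is one, otherwise allocates a new one.
  def borrow(): Codec =
    if (cached != null) {
      val c = cached
      c.reset() // defensively clear state left over from the previous borrower
      cached = null
      c
    } else {
      factory()
    }

  // Keeps the instance only if the single slot is free; otherwise it is dropped.
  def release(c: Codec): Unit =
    if (cached == null) cached = c
}

object SingleSlotPoolDemo {
  // Returns the ids of three successive borrows, with two releases in between.
  def run(): List[Int] = {
    var created = 0
    val pool = new SingleSlotPool(() => { created += 1; new Codec(created) })
    val a = pool.borrow() // served from the cache
    val b = pool.borrow() // cache is empty, so a second instance is allocated
    pool.release(a)       // fills the slot
    pool.release(b)       // slot already full: b is discarded
    val c = pool.borrow() // served from the cache again, i.e. the same object as a
    List(a.id, b.id, c.id)
  }
}
```

Two overlapping borrowers each get their own instance, which is the property that makes several simultaneously open streams safe, while the common one-at-a-time case keeps reusing a single instance.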

3 files changed: +147 -24 lines changed

core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala

Lines changed: 106 additions & 23 deletions
@@ -17,8 +17,9 @@
 
 package org.apache.spark.serializer
 
-import java.io.{EOFException, InputStream, OutputStream}
+import java.io.{EOFException, IOException, InputStream, OutputStream}
 import java.nio.ByteBuffer
+import javax.annotation.Nullable
 
 import scala.reflect.ClassTag
 
@@ -136,21 +137,45 @@ class KryoSerializer(conf: SparkConf)
 }
 
 private[spark]
-class KryoSerializationStream(kryo: Kryo, outStream: OutputStream) extends SerializationStream {
-  val output = new KryoOutput(outStream)
+class KryoSerializationStream(
+    serInstance: KryoSerializerInstance,
+    outStream: OutputStream) extends SerializationStream {
+
+  private[this] var output: KryoOutput = new KryoOutput(outStream)
+  private[this] var kryo: Kryo = serInstance.borrowKryo()
 
   override def writeObject[T: ClassTag](t: T): SerializationStream = {
     kryo.writeClassAndObject(output, t)
     this
   }
 
-  override def flush() { output.flush() }
-  override def close() { output.close() }
+  override def flush() {
+    if (output == null) {
+      throw new IOException("Stream is closed")
+    }
+    output.flush()
+  }
+
+  override def close() {
+    if (output != null) {
+      try {
+        output.close()
+      } finally {
+        serInstance.releaseKryo(kryo)
+        kryo = null
+        output = null
+      }
+    }
+  }
 }
 
 private[spark]
-class KryoDeserializationStream(kryo: Kryo, inStream: InputStream) extends DeserializationStream {
-  private val input = new KryoInput(inStream)
+class KryoDeserializationStream(
+    serInstance: KryoSerializerInstance,
+    inStream: InputStream) extends DeserializationStream {
+
+  private[this] var input: KryoInput = new KryoInput(inStream)
+  private[this] var kryo: Kryo = serInstance.borrowKryo()
 
   override def readObject[T: ClassTag](): T = {
     try {
@@ -163,52 +188,105 @@ class KryoDeserializationStream(kryo: Kryo, inStream: InputStream) extends Deser
   }
 
   override def close() {
-    // Kryo's Input automatically closes the input stream it is using.
-    input.close()
+    if (input != null) {
+      try {
+        // Kryo's Input automatically closes the input stream it is using.
+        input.close()
+      } finally {
+        serInstance.releaseKryo(kryo)
+        kryo = null
+        input = null
+      }
+    }
   }
 }
 
 private[spark] class KryoSerializerInstance(ks: KryoSerializer) extends SerializerInstance {
-  private val kryo = ks.newKryo()
 
-  // Make these lazy vals to avoid creating a buffer unless we use them
+  /**
+   * A re-used [[Kryo]] instance. Methods will borrow this instance by calling `borrowKryo()`, do
+   * their work, then release the instance by calling `releaseKryo()`. Logically, this is a caching
+   * pool of size one. SerializerInstances are not thread-safe, hence accesses to this field are
+   * not synchronized.
+   */
+  @Nullable private[this] var cachedKryo: Kryo = borrowKryo()
+
+  /**
+   * Borrows a [[Kryo]] instance. If possible, this tries to re-use a cached Kryo instance;
+   * otherwise, it allocates a new instance.
+   */
+  private[serializer] def borrowKryo(): Kryo = {
+    if (cachedKryo != null) {
+      val kryo = cachedKryo
+      // As a defensive measure, call reset() to clear any Kryo state that might have been modified
+      // by the last operation to borrow this instance (see SPARK-7766 for discussion of this issue)
+      kryo.reset()
+      cachedKryo = null
+      kryo
+    } else {
+      ks.newKryo()
+    }
+  }
+
+  /**
+   * Release a borrowed [[Kryo]] instance. If this serializer instance already has a cached Kryo
+   * instance, then the given Kryo instance is discarded; otherwise, the Kryo is stored for later
+   * re-use.
+   */
+  private[serializer] def releaseKryo(kryo: Kryo): Unit = {
+    if (cachedKryo == null) {
+      cachedKryo = kryo
+    }
+  }
+
+  // Make these lazy vals to avoid creating a buffer unless we use them.
   private lazy val output = ks.newKryoOutput()
   private lazy val input = new KryoInput()
 
   override def serialize[T: ClassTag](t: T): ByteBuffer = {
     output.clear()
-    kryo.reset() // We must reset in case this serializer instance was reused (see SPARK-7766)
+    val kryo = borrowKryo()
     try {
       kryo.writeClassAndObject(output, t)
     } catch {
       case e: KryoException if e.getMessage.startsWith("Buffer overflow") =>
         throw new SparkException(s"Kryo serialization failed: ${e.getMessage}. To avoid this, " +
           "increase spark.kryoserializer.buffer.max value.")
+    } finally {
+      releaseKryo(kryo)
     }
     ByteBuffer.wrap(output.toBytes)
   }
 
   override def deserialize[T: ClassTag](bytes: ByteBuffer): T = {
-    input.setBuffer(bytes.array)
-    kryo.readClassAndObject(input).asInstanceOf[T]
+    val kryo = borrowKryo()
+    try {
+      input.setBuffer(bytes.array)
+      kryo.readClassAndObject(input).asInstanceOf[T]
+    } finally {
+      releaseKryo(kryo)
+    }
   }
 
   override def deserialize[T: ClassTag](bytes: ByteBuffer, loader: ClassLoader): T = {
+    val kryo = borrowKryo()
     val oldClassLoader = kryo.getClassLoader
-    kryo.setClassLoader(loader)
-    input.setBuffer(bytes.array)
-    val obj = kryo.readClassAndObject(input).asInstanceOf[T]
-    kryo.setClassLoader(oldClassLoader)
-    obj
+    try {
+      kryo.setClassLoader(loader)
+      input.setBuffer(bytes.array)
+      kryo.readClassAndObject(input).asInstanceOf[T]
+    } finally {
+      kryo.setClassLoader(oldClassLoader)
+      releaseKryo(kryo)
+    }
   }
 
   override def serializeStream(s: OutputStream): SerializationStream = {
-    kryo.reset() // We must reset in case this serializer instance was reused (see SPARK-7766)
-    new KryoSerializationStream(kryo, s)
+    new KryoSerializationStream(this, s)
   }
 
   override def deserializeStream(s: InputStream): DeserializationStream = {
-    new KryoDeserializationStream(kryo, s)
+    new KryoDeserializationStream(this, s)
   }
 
   /**
@@ -218,7 +296,12 @@ private[spark] class KryoSerializerInstance(ks: KryoSerializer) extends Serializ
   def getAutoReset(): Boolean = {
     val field = classOf[Kryo].getDeclaredField("autoReset")
     field.setAccessible(true)
-    field.get(kryo).asInstanceOf[Boolean]
+    val kryo = borrowKryo()
+    try {
+      field.get(kryo).asInstanceOf[Boolean]
+    } finally {
+      releaseKryo(kryo)
+    }
   }
 }
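
The "Guard against duplicate close()" change in the stream classes above follows a general pattern: null out the fields on the first `close()` so that a second call becomes a no-op, and hand the borrowed resource back in a `finally` block so it is returned even if closing the underlying stream throws. A minimal sketch of that pattern, using only `java.io`, might look like the following. `ReleaseCounter` and `GuardedStream` are hypothetical names invented for this illustration; they are not part of Spark or Kryo.

```scala
import java.io.{ByteArrayOutputStream, IOException, OutputStream}

// Hypothetical owner that counts how many times a borrowed resource is handed back.
final class ReleaseCounter {
  var releases = 0
  def release(): Unit = releases += 1
}

// A wrapper whose close() is idempotent and always returns the borrowed resource.
final class GuardedStream(owner: ReleaseCounter, out: OutputStream) {
  // Set to null once closed; doubles as the "is closed" flag.
  private[this] var underlying: OutputStream = out

  def flush(): Unit = {
    if (underlying == null) {
      throw new IOException("Stream is closed")
    }
    underlying.flush()
  }

  def close(): Unit = {
    if (underlying != null) {
      try {
        underlying.close()
      } finally {
        // Runs even if close() above throws, so the resource is never leaked.
        owner.release()
        underlying = null
      }
    }
  }
}

object GuardedStreamDemo {
  // Closing twice must release the borrowed resource exactly once.
  def releasesAfterDoubleClose(): Int = {
    val owner = new ReleaseCounter
    val stream = new GuardedStream(owner, new ByteArrayOutputStream())
    stream.close()
    stream.close()
    owner.releases
  }

  // Flushing a closed stream should fail with an IOException.
  def flushAfterCloseFails(): Boolean = {
    val stream = new GuardedStream(new ReleaseCounter, new ByteArrayOutputStream())
    stream.close()
    try { stream.flush(); false } catch { case _: IOException => true }
  }
}
```

Without the null check, a duplicate `close()` would release the same instance twice, and with a capacity-one pool a second release could put an instance back in the cache while another stream is still using it.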

core/src/main/scala/org/apache/spark/serializer/Serializer.scala

Lines changed: 5 additions & 0 deletions
@@ -19,6 +19,7 @@ package org.apache.spark.serializer
 
 import java.io._
 import java.nio.ByteBuffer
+import javax.annotation.concurrent.NotThreadSafe
 
 import scala.reflect.ClassTag
 
@@ -114,8 +115,12 @@ object Serializer {
 /**
  * :: DeveloperApi ::
  * An instance of a serializer, for use by one thread at a time.
+ *
+ * It is legal to create multiple serialization / deserialization streams from the same
+ * SerializerInstance as long as those streams are all used within the same thread.
  */
 @DeveloperApi
+@NotThreadSafe
 abstract class SerializerInstance {
   def serialize[T: ClassTag](t: T): ByteBuffer

core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala

Lines changed: 36 additions & 1 deletion
@@ -17,7 +17,7 @@
 
 package org.apache.spark.serializer
 
-import java.io.ByteArrayOutputStream
+import java.io.{ByteArrayInputStream, ByteArrayOutputStream}
 
 import scala.collection.mutable
 import scala.reflect.ClassTag
@@ -354,6 +354,41 @@ class KryoSerializerSuite extends FunSuite with SharedSparkContext {
   }
 }
 
+class KryoSerializerAutoResetDisabledSuite extends FunSuite with SharedSparkContext {
+  conf.set("spark.serializer", classOf[KryoSerializer].getName)
+  conf.set("spark.kryo.registrator", classOf[RegistratorWithoutAutoReset].getName)
+  conf.set("spark.kryo.referenceTracking", "true")
+  conf.set("spark.shuffle.manager", "sort")
+  conf.set("spark.shuffle.sort.bypassMergeThreshold", "200")
+
+  test("sort-shuffle with bypassMergeSort (SPARK-7873)") {
+    val myObject = ("Hello", "World")
+    assert(sc.parallelize(Seq.fill(100)(myObject)).repartition(2).collect().toSet === Set(myObject))
+  }
+
+  test("calling deserialize() after deserializeStream()") {
+    val serInstance = new KryoSerializer(conf).newInstance().asInstanceOf[KryoSerializerInstance]
+    assert(!serInstance.getAutoReset())
+    val hello = "Hello"
+    val world = "World"
+    // Here, we serialize the same value twice, so the reference-tracking should cause us to store
+    // references to some of these values
+    val helloHello = serInstance.serialize((hello, hello))
+    // Here's a stream which only contains one value
+    val worldWorld: Array[Byte] = {
+      val baos = new ByteArrayOutputStream()
+      val serStream = serInstance.serializeStream(baos)
+      serStream.writeObject(world)
+      serStream.writeObject(world)
+      serStream.close()
+      baos.toByteArray
+    }
+    val deserializationStream = serInstance.deserializeStream(new ByteArrayInputStream(worldWorld))
+    assert(deserializationStream.readValue[Any]() === world)
+    deserializationStream.close()
+    assert(serInstance.deserialize[Any](helloHello) === (hello, hello))
+  }
+}
 
 class ClassLoaderTestingObject
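
The failure mode that these regression tests target can be pictured with a toy model. The sketch below is not Kryo and does not reproduce its wire format: `ToyRefWriter` is a hypothetical writer that, like a reference-tracking serializer with auto-reset disabled, remembers every value it has written and emits a back-reference the next time it sees that value. The assumption illustrated is that sharing such state across two output streams lets one stream refer to a value that only exists in the other.

```scala
import scala.collection.mutable

// Toy model (NOT Kryo): writes each distinct value once, then emits back-references.
final class ToyRefWriter {
  // State that survives between writes unless reset() is called.
  private val seen = mutable.Map.empty[String, Int]

  def write(out: mutable.Buffer[String], value: String): Unit =
    seen.get(value) match {
      case Some(id) =>
        out += s"ref:$id" // only meaningful if the value is earlier in the SAME stream
      case None =>
        seen(value) = seen.size
        out += s"val:$value"
    }

  def reset(): Unit = seen.clear()
}

object ToyRefDemo {
  // One writer shared by two streams, never reset between them.
  def sharedWriter(): (List[String], List[String]) = {
    val writer = new ToyRefWriter
    val s1 = mutable.ArrayBuffer.empty[String]
    val s2 = mutable.ArrayBuffer.empty[String]
    writer.write(s1, "World")
    writer.write(s2, "World") // emits a reference into s1's history
    (s1.toList, s2.toList)
  }

  // One writer per stream, as a borrow/release pool provides when two streams overlap.
  def writerPerStream(): (List[String], List[String]) = {
    val s1 = mutable.ArrayBuffer.empty[String]
    val s2 = mutable.ArrayBuffer.empty[String]
    new ToyRefWriter().write(s1, "World")
    new ToyRefWriter().write(s2, "World")
    (s1.toList, s2.toList)
  }
}
```

In the shared case the second stream holds only a dangling reference, so a reader that sees that stream on its own cannot recover the value. Giving each open stream its own writer, or resetting the writer before reuse, keeps every stream self-contained.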
