Skip to content

[SPARK-7873] Allow KryoSerializerInstance to create multiple streams at the same time #6415

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 6 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
129 changes: 106 additions & 23 deletions core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,9 @@

package org.apache.spark.serializer

import java.io.{EOFException, InputStream, OutputStream}
import java.io.{EOFException, IOException, InputStream, OutputStream}
import java.nio.ByteBuffer
import javax.annotation.Nullable

import scala.reflect.ClassTag

Expand Down Expand Up @@ -136,21 +137,45 @@ class KryoSerializer(conf: SparkConf)
}

/**
 * A [[SerializationStream]] that writes objects to `outStream` using a Kryo instance borrowed
 * from `serInstance`.
 *
 * The Kryo instance is borrowed when the stream is constructed and released back to
 * `serInstance` when the stream is closed. Once closed, `writeObject` and `flush` throw an
 * `IOException`; calling `close()` again is a no-op.
 */
private[spark]
class KryoSerializationStream(
    serInstance: KryoSerializerInstance,
    outStream: OutputStream) extends SerializationStream {

  // Both fields are set to null by close(); a null `output` marks the stream as closed.
  private[this] var output: KryoOutput = new KryoOutput(outStream)
  private[this] var kryo: Kryo = serInstance.borrowKryo()

  override def writeObject[T: ClassTag](t: T): SerializationStream = {
    // Fail the same way flush() does instead of dereferencing the nulled-out fields.
    if (output == null) {
      throw new IOException("Stream is closed")
    }
    kryo.writeClassAndObject(output, t)
    this
  }

  override def flush(): Unit = {
    if (output == null) {
      throw new IOException("Stream is closed")
    }
    output.flush()
  }

  override def close(): Unit = {
    if (output != null) {
      try {
        output.close()
      } finally {
        // Release the Kryo instance even if closing the output failed.
        serInstance.releaseKryo(kryo)
        kryo = null
        output = null
      }
    }
  }
}

private[spark]
class KryoDeserializationStream(kryo: Kryo, inStream: InputStream) extends DeserializationStream {
private val input = new KryoInput(inStream)
class KryoDeserializationStream(
serInstance: KryoSerializerInstance,
inStream: InputStream) extends DeserializationStream {

private[this] var input: KryoInput = new KryoInput(inStream)
private[this] var kryo: Kryo = serInstance.borrowKryo()

override def readObject[T: ClassTag](): T = {
try {
Expand All @@ -163,52 +188,105 @@ class KryoDeserializationStream(kryo: Kryo, inStream: InputStream) extends Deser
}

override def close() {
// Kryo's Input automatically closes the input stream it is using.
input.close()
if (input != null) {
try {
// Kryo's Input automatically closes the input stream it is using.
input.close()
} finally {
serInstance.releaseKryo(kryo)
kryo = null
input = null
}
}
}
}

private[spark] class KryoSerializerInstance(ks: KryoSerializer) extends SerializerInstance {
private val kryo = ks.newKryo()

// Make these lazy vals to avoid creating a buffer unless we use them
/**
* A re-used [[Kryo]] instance. Methods will borrow this instance by calling `borrowKryo()`, do
* their work, then release the instance by calling `releaseKryo()`. Logically, this is a caching
* pool of size one. SerializerInstances are not thread-safe, hence accesses to this field are
* not synchronized.
*/
@Nullable private[this] var cachedKryo: Kryo = borrowKryo()

/**
 * Hands out a [[Kryo]] instance for the caller to use. The cached instance is returned when one
 * is available (and the cache slot is emptied); otherwise a brand new instance is created.
 */
private[serializer] def borrowKryo(): Kryo = {
  val cached = cachedKryo
  if (cached == null) {
    // Nothing cached: allocate a fresh instance.
    ks.newKryo()
  } else {
    // Defensively reset() the instance so that no state left behind by its previous borrower
    // leaks into the next operation (see SPARK-7766 for discussion of this issue).
    cached.reset()
    cachedKryo = null
    cached
  }
}

/**
 * Gives a borrowed [[Kryo]] instance back. The instance is kept for later re-use only when the
 * cache slot is currently empty; if another instance is already cached, the given one is
 * simply dropped.
 */
private[serializer] def releaseKryo(kryo: Kryo): Unit = {
  val slotIsEmpty = cachedKryo == null
  if (slotIsEmpty) {
    cachedKryo = kryo
  }
}

// Make these lazy vals to avoid creating a buffer unless we use them.
private lazy val output = ks.newKryoOutput()
private lazy val input = new KryoInput()

/**
 * Serializes `t` into this instance's reusable output buffer and returns the written bytes
 * wrapped in a `ByteBuffer`.
 *
 * A Kryo "Buffer overflow" failure is rethrown as a `SparkException` that points at the
 * `spark.kryoserializer.buffer.max` setting and keeps the original exception as its cause.
 */
override def serialize[T: ClassTag](t: T): ByteBuffer = {
  output.clear()
  val kryo = borrowKryo()
  try {
    kryo.writeClassAndObject(output, t)
  } catch {
    // Option(...) guards against an exception that carries no message.
    case e: KryoException if Option(e.getMessage).exists(_.startsWith("Buffer overflow")) =>
      throw new SparkException(s"Kryo serialization failed: ${e.getMessage}. To avoid this, " +
        "increase spark.kryoserializer.buffer.max value.", e)
  } finally {
    releaseKryo(kryo)
  }
  ByteBuffer.wrap(output.toBytes)
}

/**
 * Deserializes one object from the remaining bytes of `bytes`, i.e. the region between the
 * buffer's current position and its limit. The position of `bytes` is left untouched.
 */
override def deserialize[T: ClassTag](bytes: ByteBuffer): T = {
  val kryo = borrowKryo()
  try {
    if (bytes.hasArray) {
      // Read only the live region of the backing array: a sliced or partially consumed buffer
      // does not start at index 0 of `bytes.array`.
      input.setBuffer(bytes.array, bytes.arrayOffset + bytes.position, bytes.remaining)
    } else {
      // No accessible backing array: copy the remaining bytes out through a duplicate so the
      // caller's position is not advanced.
      val copy = new Array[Byte](bytes.remaining)
      bytes.duplicate().get(copy)
      input.setBuffer(copy)
    }
    kryo.readClassAndObject(input).asInstanceOf[T]
  } finally {
    releaseKryo(kryo)
  }
}

/**
 * Deserializes one object from the remaining bytes of `bytes` (the region between the buffer's
 * current position and its limit), resolving classes through `loader`. The class loader that
 * the borrowed Kryo instance had before the call is restored before the instance is released.
 * The position of `bytes` is left untouched.
 */
override def deserialize[T: ClassTag](bytes: ByteBuffer, loader: ClassLoader): T = {
  val kryo = borrowKryo()
  val oldClassLoader = kryo.getClassLoader
  try {
    kryo.setClassLoader(loader)
    if (bytes.hasArray) {
      // Read only the live region of the backing array: a sliced or partially consumed buffer
      // does not start at index 0 of `bytes.array`.
      input.setBuffer(bytes.array, bytes.arrayOffset + bytes.position, bytes.remaining)
    } else {
      // No accessible backing array: copy the remaining bytes out through a duplicate so the
      // caller's position is not advanced.
      val copy = new Array[Byte](bytes.remaining)
      bytes.duplicate().get(copy)
      input.setBuffer(copy)
    }
    kryo.readClassAndObject(input).asInstanceOf[T]
  } finally {
    kryo.setClassLoader(oldClassLoader)
    releaseKryo(kryo)
  }
}

/**
 * Opens a serialization stream over `s`. The stream is handed this serializer instance rather
 * than a Kryo object, and borrows a Kryo instance from it on construction.
 */
override def serializeStream(s: OutputStream): SerializationStream =
  new KryoSerializationStream(this, s)

/**
 * Opens a deserialization stream over `s`. The stream is handed this serializer instance rather
 * than a Kryo object, and borrows a Kryo instance from it on construction.
 */
override def deserializeStream(s: InputStream): DeserializationStream =
  new KryoDeserializationStream(this, s)

/**
Expand All @@ -218,7 +296,12 @@ private[spark] class KryoSerializerInstance(ks: KryoSerializer) extends Serializ
def getAutoReset(): Boolean = {
  // `autoReset` is read reflectively from the Kryo class, so the field is looked up and made
  // accessible first.
  val autoResetField = classOf[Kryo].getDeclaredField("autoReset")
  autoResetField.setAccessible(true)
  // Borrow an instance only for the duration of the read and always hand it back.
  val borrowed = borrowKryo()
  try {
    autoResetField.get(borrowed).asInstanceOf[Boolean]
  } finally {
    releaseKryo(borrowed)
  }
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ package org.apache.spark.serializer

import java.io._
import java.nio.ByteBuffer
import javax.annotation.concurrent.NotThreadSafe

import scala.reflect.ClassTag

Expand Down Expand Up @@ -114,8 +115,12 @@ object Serializer {
/**
* :: DeveloperApi ::
* An instance of a serializer, for use by one thread at a time.
*
* It is legal to create multiple serialization / deserialization streams from the same
* SerializerInstance as long as those streams are all used within the same thread.
*/
@DeveloperApi
@NotThreadSafe
abstract class SerializerInstance {
def serialize[T: ClassTag](t: T): ByteBuffer

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@

package org.apache.spark.serializer

import java.io.ByteArrayOutputStream
import java.io.{ByteArrayInputStream, ByteArrayOutputStream}

import scala.collection.mutable
import scala.reflect.ClassTag
Expand Down Expand Up @@ -361,6 +361,41 @@ class KryoSerializerSuite extends FunSuite with SharedSparkContext {
}
}

/**
 * Exercises [[KryoSerializer]] with reference tracking enabled and a registrator under which
 * `getAutoReset()` reports false, using the sort shuffle manager with a bypass-merge threshold
 * of 200.
 */
class KryoSerializerAutoResetDisabledSuite extends FunSuite with SharedSparkContext {
  conf
    .set("spark.serializer", classOf[KryoSerializer].getName)
    .set("spark.kryo.registrator", classOf[RegistratorWithoutAutoReset].getName)
    .set("spark.kryo.referenceTracking", "true")
    .set("spark.shuffle.manager", "sort")
    .set("spark.shuffle.sort.bypassMergeThreshold", "200")

  test("sort-shuffle with bypassMergeSort (SPARK-7873)") {
    val pair = ("Hello", "World")
    val collected = sc.parallelize(Seq.fill(100)(pair)).repartition(2).collect().toSet
    assert(collected === Set(pair))
  }

  test("calling deserialize() after deserializeStream()") {
    val instance = new KryoSerializer(conf).newInstance().asInstanceOf[KryoSerializerInstance]
    assert(!instance.getAutoReset())
    val hello = "Hello"
    val world = "World"

    // The same value appears twice in the tuple, so reference tracking should encode part of
    // it as a back-reference.
    val serializedPair = instance.serialize((hello, hello))

    // Build a stream that holds `world` written twice through a single serialization stream.
    val sink = new ByteArrayOutputStream()
    val writer = instance.serializeStream(sink)
    writer.writeObject(world)
    writer.writeObject(world)
    writer.close()
    val streamBytes: Array[Byte] = sink.toByteArray

    // Read one value back through a deserialization stream, close it, and only then use the
    // one-shot deserialize() on the bytes produced earlier.
    val reader = instance.deserializeStream(new ByteArrayInputStream(streamBytes))
    assert(reader.readValue[Any]() === world)
    reader.close()
    assert(instance.deserialize[Any](serializedPair) === (hello, hello))
  }
}

class ClassLoaderTestingObject

Expand Down