[SPARK-22328][Core] ClosureCleaner should not miss referenced superclass fields #19556

Closed · wants to merge 7 commits

73 changes: 61 additions & 12 deletions core/src/main/scala/org/apache/spark/util/ClosureCleaner.scala
@@ -91,6 +91,54 @@ private[spark] object ClosureCleaner extends Logging {
    (seen - obj.getClass).toList
  }

  /** Initializes the accessed fields for outer classes and their super classes. */
  private def initAccessedFields(
      accessedFields: Map[Class[_], Set[String]],
      outerClasses: Seq[Class[_]]): Unit = {

Contributor:
What if multiple outer classes have the same parent class?

Member Author:
I think that, from the closure's point of view, even if multiple outer classes share the same parent class, access to the parent class's fields shouldn't conflict.

Member Author:
I added a related test. Please see whether it addresses your concern.


    for (cls <- outerClasses) {
      var currentClass = cls
      assert(currentClass != null, "The outer class can't be null.")

      while (currentClass != null) {
        accessedFields(currentClass) = Set.empty[String]
        currentClass = currentClass.getSuperclass()
      }
    }
  }
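
To make the review exchange above concrete, here is a minimal, standalone sketch of the initialization walk when two outer classes share a parent. The names (Parent, OuterA, OuterB, SharedParentSketch) are hypothetical and not part of the patch: the shared parent simply gets a single map entry keyed by its Class object, and field values are only copied per outer object later, by the setAccessedFields helper below.

import scala.collection.mutable

object SharedParentSketch extends App {
  // Hypothetical stand-ins for two outer classes that extend the same parent.
  abstract class Parent extends Serializable { val p = 1 }
  class OuterA extends Parent
  class OuterB extends Parent

  val accessedFields = mutable.Map.empty[Class[_], mutable.Set[String]]

  // The same walk initAccessedFields does: every class in each outer class's
  // hierarchy gets its own entry, keyed by the Class object.
  for (cls <- Seq(classOf[OuterA], classOf[OuterB])) {
    var current: Class[_] = cls
    while (current != null) {
      accessedFields(current) = mutable.Set.empty[String]
      current = current.getSuperclass
    }
  }

  // Keys are OuterA, OuterB, one shared entry for Parent, and java.lang.Object;
  // the field names (and later the field values) are filled in per outer object,
  // so a shared parent class causes no conflict between outer classes.
  println(accessedFields.keys.map(_.getSimpleName).toList.sorted)
}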

  /** Sets accessed fields for given class in clone object based on given object. */
  private def setAccessedFields(
      outerClass: Class[_],
      clone: AnyRef,
      obj: AnyRef,
      accessedFields: Map[Class[_], Set[String]]): Unit = {
    for (fieldName <- accessedFields(outerClass)) {
      val field = outerClass.getDeclaredField(fieldName)
      field.setAccessible(true)
      val value = field.get(obj)
      field.set(clone, value)
    }
  }

  /** Clones a given object and sets accessed fields in cloned object. */
  private def cloneAndSetFields(
      parent: AnyRef,
      obj: AnyRef,
      outerClass: Class[_],
      accessedFields: Map[Class[_], Set[String]]): AnyRef = {
    val clone = instantiateClass(outerClass, parent)

    var currentClass = outerClass
    assert(currentClass != null, "The outer class can't be null.")

    while (currentClass != null) {
      setAccessedFields(currentClass, clone, obj, accessedFields)
      currentClass = currentClass.getSuperclass()
    }

    clone
  }
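
The superclass walk above is central to the fix: java.lang.Class.getDeclaredField only returns fields declared directly on the receiver class, never inherited ones, so looking at the outer class alone misses anything defined in a parent. Below is a small sketch of the per-class reflective copy that setAccessedFields performs; Base, Impl, and the object name are hypothetical stand-ins, not part of the patch.

object SuperclassFieldSketch extends App {
  // Hypothetical classes: the field n1 lives in Base, not in Impl.
  abstract class Base(val n1: Int) extends Serializable
  class Impl(n: Int) extends Base(n)

  val source = new Impl(222)
  val target = new Impl(0)

  // classOf[Impl].getDeclaredField("n1") would throw NoSuchFieldException,
  // because getDeclaredField only sees fields declared on that exact class.
  // The lookup has to happen on the declaring class, which is why
  // cloneAndSetFields walks the hierarchy and calls setAccessedFields per class.
  val field = classOf[Base].getDeclaredField("n1")
  field.setAccessible(true)
  field.set(target, field.get(source))

  println(target.n1)  // 222, copied from source through the superclass field
}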

  /**
   * Clean the given closure in place.
   *

@@ -202,9 +250,8 @@ private[spark] object ClosureCleaner extends Logging {
      logDebug(s" + populating accessed fields because this is the starting closure")
      // Initialize accessed fields with the outer classes first
      // This step is needed to associate the fields to the correct classes later
-     for (cls <- outerClasses) {
-       accessedFields(cls) = Set.empty[String]
-     }
+     initAccessedFields(accessedFields, outerClasses)

      // Populate accessed fields by visiting all fields and methods accessed by this and
      // all of its inner closures. If transitive cleaning is enabled, this may recursively
      // visit methods that belong to other classes in search of transitively referenced fields.

@@ -250,13 +297,8 @@ private[spark] object ClosureCleaner extends Logging {
        // required fields from the original object. We need the parent here because the Java
        // language specification requires the first constructor parameter of any closure to be
        // its enclosing object.
-       val clone = instantiateClass(cls, parent)
-       for (fieldName <- accessedFields(cls)) {
-         val field = cls.getDeclaredField(fieldName)
-         field.setAccessible(true)
-         val value = field.get(obj)
-         field.set(clone, value)
-       }
+       val clone = cloneAndSetFields(parent, obj, cls, accessedFields)

        // If transitive cleaning is enabled, we recursively clean any enclosing closure using
        // the already populated accessed fields map of the starting closure
        if (cleanTransitively && isClosure(clone.getClass)) {

@@ -395,8 +437,15 @@ private[util] class FieldAccessFinder(
          if (!visitedMethods.contains(m)) {
            // Keep track of visited methods to avoid potential infinite cycles
            visitedMethods += m
-           ClosureCleaner.getClassReader(cl).accept(
-             new FieldAccessFinder(fields, findTransitively, Some(m), visitedMethods), 0)
+
+           var currentClass = cl
+           assert(currentClass != null, "The outer class can't be null.")
+
+           while (currentClass != null) {
+             ClosureCleaner.getClassReader(currentClass).accept(
+               new FieldAccessFinder(fields, findTransitively, Some(m), visitedMethods), 0)
+             currentClass = currentClass.getSuperclass()
+           }
          }
        }
      }
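
For the same reason, FieldAccessFinder now reads the bytecode of every class in the hierarchy rather than only the class at hand. A rough sketch of the chain it walks before handing each class to ClosureCleaner.getClassReader; the Base class and object name are hypothetical, not part of the patch.

object HierarchyWalkSketch extends App {
  // Hypothetical parent, mirroring the TestAbstractClass2 used in the new tests.
  abstract class Base extends Serializable { val n1 = 111 }

  val outer = new Base { val n2 = 222 }

  // The chain the patched finder walks: the concrete (anonymous) subclass
  // first, then each superclass up to java.lang.Object.
  var current: Class[_] = outer.getClass
  while (current != null) {
    println(current.getName)
    current = current.getSuperclass
  }

  // Before this change only the class itself was read, so methods defined only
  // in Base (such as the n1 accessor) were never scanned and their field
  // accesses went unrecorded.
}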

core/src/test/scala/org/apache/spark/util/ClosureCleanerSuite.scala
@@ -119,6 +119,63 @@ class ClosureCleanerSuite extends SparkFunSuite {
  test("createNullValue") {
    new TestCreateNullValue().run()
  }

  test("SPARK-22328: ClosureCleaner misses referenced superclass fields: case 1") {
    val concreteObject = new TestAbstractClass {
      val n2 = 222
      val s2 = "bbb"
      val d2 = 2.0d

      def run(): Seq[(Int, Int, String, String, Double, Double)] = {
        withSpark(new SparkContext("local", "test")) { sc =>
          val rdd = sc.parallelize(1 to 1)
          body(rdd)
        }
      }

      def body(rdd: RDD[Int]): Seq[(Int, Int, String, String, Double, Double)] = rdd.map { _ =>
        (n1, n2, s1, s2, d1, d2)
      }.collect()
    }
    assert(concreteObject.run() === Seq((111, 222, "aaa", "bbb", 1.0d, 2.0d)))
  }

  test("SPARK-22328: ClosureCleaner misses referenced superclass fields: case 2") {
    val concreteObject = new TestAbstractClass2 {
      val n2 = 222
      val s2 = "bbb"
      val d2 = 2.0d
      def getData: Int => (Int, Int, String, String, Double, Double) = _ => (n1, n2, s1, s2, d1, d2)
    }
    withSpark(new SparkContext("local", "test")) { sc =>
      val rdd = sc.parallelize(1 to 1).map(concreteObject.getData)
      assert(rdd.collect() === Seq((111, 222, "aaa", "bbb", 1.0d, 2.0d)))
    }
  }

  test("SPARK-22328: multiple outer classes have the same parent class") {
    val concreteObject = new TestAbstractClass2 {

      val innerObject = new TestAbstractClass2 {
        override val n1 = 222
        override val s1 = "bbb"
      }

      val innerObject2 = new TestAbstractClass2 {
        override val n1 = 444
        val n3 = 333
        val s3 = "ccc"
        val d3 = 3.0d

        def getData: Int => (Int, Int, String, String, Double, Double, Int, String) =
          _ => (n1, n3, s1, s3, d1, d3, innerObject.n1, innerObject.s1)
      }
    }
    withSpark(new SparkContext("local", "test")) { sc =>
      val rdd = sc.parallelize(1 to 1).map(concreteObject.innerObject2.getData)
      assert(rdd.collect() === Seq((444, 333, "aaa", "ccc", 1.0d, 3.0d, 222, "bbb")))
    }
  }
}

// A non-serializable class we create in closures to make sure that we aren't

@@ -377,3 +434,18 @@ class TestCreateNullValue {
    nestedClosure()
  }
}

abstract class TestAbstractClass extends Serializable {
  val n1 = 111
  val s1 = "aaa"
  protected val d1 = 1.0d

  def run(): Seq[(Int, Int, String, String, Double, Double)]
  def body(rdd: RDD[Int]): Seq[(Int, Int, String, String, Double, Double)]
}

abstract class TestAbstractClass2 extends Serializable {
  val n1 = 111
  val s1 = "aaa"
  protected val d1 = 1.0d
}