
Commit 4f8dc6b

viirya authored and cloud-fan committed
[SPARK-22328][CORE] ClosureCleaner should not miss referenced superclass fields
## What changes were proposed in this pull request?

When the given closure uses fields defined in its superclass, `ClosureCleaner` fails to find them and does not set them on the cloned outer object, so those fields end up as null (or default) values.

## How was this patch tested?

Added tests.

Author: Liang-Chi Hsieh <viirya@gmail.com>

Closes apache#19556 from viirya/SPARK-22328.
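To make the failure mode concrete, here is a minimal standalone sketch modeled on the new test cases (names such as `Base` and `Repro` are illustrative and not part of the patch):

```scala
import org.apache.spark.SparkContext

// A field declared in a superclass and read from a closure: before this patch,
// ClosureCleaner only registered fields declared directly on the outer classes,
// so an inherited field like n1 could come back as its default/null value.
abstract class Base extends Serializable {
  val n1 = 111
}

object Repro {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext("local", "SPARK-22328 repro")
    try {
      val obj = new Base {
        val n2 = 222
        // The returned closure reads both the inherited n1 and the local n2.
        def getData: Int => (Int, Int) = _ => (n1, n2)
      }
      // With the fix this prints (111,222); previously the inherited n1 could be
      // lost when the outer object was cloned during closure cleaning.
      sc.parallelize(1 to 1).map(obj.getData).collect().foreach(println)
    } finally {
      sc.stop()
    }
  }
}
```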
1 parent 0e9a750 commit 4f8dc6b

File tree

2 files changed, +133 -12 lines


core/src/main/scala/org/apache/spark/util/ClosureCleaner.scala

Lines changed: 61 additions & 12 deletions
@@ -91,6 +91,54 @@ private[spark] object ClosureCleaner extends Logging {
     (seen - obj.getClass).toList
   }
 
+  /** Initializes the accessed fields for outer classes and their super classes. */
+  private def initAccessedFields(
+      accessedFields: Map[Class[_], Set[String]],
+      outerClasses: Seq[Class[_]]): Unit = {
+    for (cls <- outerClasses) {
+      var currentClass = cls
+      assert(currentClass != null, "The outer class can't be null.")
+
+      while (currentClass != null) {
+        accessedFields(currentClass) = Set.empty[String]
+        currentClass = currentClass.getSuperclass()
+      }
+    }
+  }
+
+  /** Sets accessed fields for given class in clone object based on given object. */
+  private def setAccessedFields(
+      outerClass: Class[_],
+      clone: AnyRef,
+      obj: AnyRef,
+      accessedFields: Map[Class[_], Set[String]]): Unit = {
+    for (fieldName <- accessedFields(outerClass)) {
+      val field = outerClass.getDeclaredField(fieldName)
+      field.setAccessible(true)
+      val value = field.get(obj)
+      field.set(clone, value)
+    }
+  }
+
+  /** Clones a given object and sets accessed fields in cloned object. */
+  private def cloneAndSetFields(
+      parent: AnyRef,
+      obj: AnyRef,
+      outerClass: Class[_],
+      accessedFields: Map[Class[_], Set[String]]): AnyRef = {
+    val clone = instantiateClass(outerClass, parent)
+
+    var currentClass = outerClass
+    assert(currentClass != null, "The outer class can't be null.")
+
+    while (currentClass != null) {
+      setAccessedFields(currentClass, clone, obj, accessedFields)
+      currentClass = currentClass.getSuperclass()
+    }
+
+    clone
+  }
+
   /**
    * Clean the given closure in place.
    *
@@ -202,9 +250,8 @@ private[spark] object ClosureCleaner extends Logging {
       logDebug(s" + populating accessed fields because this is the starting closure")
       // Initialize accessed fields with the outer classes first
       // This step is needed to associate the fields to the correct classes later
-      for (cls <- outerClasses) {
-        accessedFields(cls) = Set.empty[String]
-      }
+      initAccessedFields(accessedFields, outerClasses)
+
       // Populate accessed fields by visiting all fields and methods accessed by this and
       // all of its inner closures. If transitive cleaning is enabled, this may recursively
       // visits methods that belong to other classes in search of transitively referenced fields.
@@ -250,13 +297,8 @@ private[spark] object ClosureCleaner extends Logging {
       // required fields from the original object. We need the parent here because the Java
       // language specification requires the first constructor parameter of any closure to be
       // its enclosing object.
-      val clone = instantiateClass(cls, parent)
-      for (fieldName <- accessedFields(cls)) {
-        val field = cls.getDeclaredField(fieldName)
-        field.setAccessible(true)
-        val value = field.get(obj)
-        field.set(clone, value)
-      }
+      val clone = cloneAndSetFields(parent, obj, cls, accessedFields)
+
       // If transitive cleaning is enabled, we recursively clean any enclosing closure using
       // the already populated accessed fields map of the starting closure
       if (cleanTransitively && isClosure(clone.getClass)) {
@@ -395,8 +437,15 @@ private[util] class FieldAccessFinder(
           if (!visitedMethods.contains(m)) {
             // Keep track of visited methods to avoid potential infinite cycles
             visitedMethods += m
-            ClosureCleaner.getClassReader(cl).accept(
-              new FieldAccessFinder(fields, findTransitively, Some(m), visitedMethods), 0)
+
+            var currentClass = cl
+            assert(currentClass != null, "The outer class can't be null.")
+
+            while (currentClass != null) {
+              ClosureCleaner.getClassReader(currentClass).accept(
+                new FieldAccessFinder(fields, findTransitively, Some(m), visitedMethods), 0)
+              currentClass = currentClass.getSuperclass()
+            }
           }
         }
       }
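The helpers and the `FieldAccessFinder` change above walk up the superclass chain because `Class#getDeclaredField` and `getDeclaredFields` only return fields declared directly on the given class, not inherited ones. A small standalone sketch of that reflection behavior (the `Base`/`Child` classes here are illustrative, not part of Spark):

```scala
// Illustrates why initAccessedFields/cloneAndSetFields loop over getSuperclass():
// getDeclaredField only sees fields declared on the class itself.
class Base { val fromBase = 1 }
class Child extends Base { val fromChild = 2 }

object DeclaredFieldDemo {
  def main(args: Array[String]): Unit = {
    val cls: Class[_] = classOf[Child]
    // Declared directly on Child: found.
    println(cls.getDeclaredField("fromChild").getName)

    // Declared on Base: not visible via Child.getDeclaredField, so we walk up
    // the hierarchy until a class that declares it is found, as the patch does.
    var current: Class[_] = cls
    while (current != null && !current.getDeclaredFields.exists(_.getName == "fromBase")) {
      current = current.getSuperclass
    }
    println(s"fromBase is declared on: ${Option(current).map(_.getName).getOrElse("not found")}")
  }
}
```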

core/src/test/scala/org/apache/spark/util/ClosureCleanerSuite.scala

Lines changed: 72 additions & 0 deletions
@@ -119,6 +119,63 @@ class ClosureCleanerSuite extends SparkFunSuite {
   test("createNullValue") {
     new TestCreateNullValue().run()
   }
+
+  test("SPARK-22328: ClosureCleaner misses referenced superclass fields: case 1") {
+    val concreteObject = new TestAbstractClass {
+      val n2 = 222
+      val s2 = "bbb"
+      val d2 = 2.0d
+
+      def run(): Seq[(Int, Int, String, String, Double, Double)] = {
+        withSpark(new SparkContext("local", "test")) { sc =>
+          val rdd = sc.parallelize(1 to 1)
+          body(rdd)
+        }
+      }
+
+      def body(rdd: RDD[Int]): Seq[(Int, Int, String, String, Double, Double)] = rdd.map { _ =>
+        (n1, n2, s1, s2, d1, d2)
+      }.collect()
+    }
+    assert(concreteObject.run() === Seq((111, 222, "aaa", "bbb", 1.0d, 2.0d)))
+  }
+
+  test("SPARK-22328: ClosureCleaner misses referenced superclass fields: case 2") {
+    val concreteObject = new TestAbstractClass2 {
+      val n2 = 222
+      val s2 = "bbb"
+      val d2 = 2.0d
+      def getData: Int => (Int, Int, String, String, Double, Double) = _ => (n1, n2, s1, s2, d1, d2)
+    }
+    withSpark(new SparkContext("local", "test")) { sc =>
+      val rdd = sc.parallelize(1 to 1).map(concreteObject.getData)
+      assert(rdd.collect() === Seq((111, 222, "aaa", "bbb", 1.0d, 2.0d)))
+    }
+  }
+
+  test("SPARK-22328: multiple outer classes have the same parent class") {
+    val concreteObject = new TestAbstractClass2 {
+
+      val innerObject = new TestAbstractClass2 {
+        override val n1 = 222
+        override val s1 = "bbb"
+      }
+
+      val innerObject2 = new TestAbstractClass2 {
+        override val n1 = 444
+        val n3 = 333
+        val s3 = "ccc"
+        val d3 = 3.0d
+
+        def getData: Int => (Int, Int, String, String, Double, Double, Int, String) =
+          _ => (n1, n3, s1, s3, d1, d3, innerObject.n1, innerObject.s1)
+      }
+    }
+    withSpark(new SparkContext("local", "test")) { sc =>
+      val rdd = sc.parallelize(1 to 1).map(concreteObject.innerObject2.getData)
+      assert(rdd.collect() === Seq((444, 333, "aaa", "ccc", 1.0d, 3.0d, 222, "bbb")))
+    }
+  }
 }
 
 // A non-serializable class we create in closures to make sure that we aren't
@@ -377,3 +434,18 @@ class TestCreateNullValue {
     nestedClosure()
   }
 }
+
+abstract class TestAbstractClass extends Serializable {
+  val n1 = 111
+  val s1 = "aaa"
+  protected val d1 = 1.0d
+
+  def run(): Seq[(Int, Int, String, String, Double, Double)]
+  def body(rdd: RDD[Int]): Seq[(Int, Int, String, String, Double, Double)]
+}
+
+abstract class TestAbstractClass2 extends Serializable {
+  val n1 = 111
+  val s1 = "aaa"
+  protected val d1 = 1.0d
+}
