Skip to content

Commit 0899401

Browse files
committed
SPARK-31399: support indylambda Scala closure cleaning in ClosureCleaner
1 parent f05560b commit 0899401

File tree

1 file changed

+194
-45
lines changed

1 file changed

+194
-45
lines changed

core/src/main/scala/org/apache/spark/util/ClosureCleaner.scala

+194-45
Original file line numberDiff line numberDiff line change
@@ -18,12 +18,15 @@
1818
package org.apache.spark.util
1919

2020
import java.io.{ByteArrayInputStream, ByteArrayOutputStream}
21-
import java.lang.invoke.SerializedLambda
21+
import java.lang.invoke.{MethodHandleInfo, SerializedLambda}
2222

23+
import scala.collection.JavaConverters._
2324
import scala.collection.mutable.{Map, Set, Stack}
2425

25-
import org.apache.xbean.asm7.{ClassReader, ClassVisitor, MethodVisitor, Type}
26+
import org.apache.commons.lang3.ClassUtils
27+
import org.apache.xbean.asm7.{ClassReader, ClassVisitor, Handle, MethodVisitor, Type}
2628
import org.apache.xbean.asm7.Opcodes._
29+
import org.apache.xbean.asm7.tree.{ClassNode, MethodNode}
2730

2831
import org.apache.spark.{SparkEnv, SparkException}
2932
import org.apache.spark.internal.Logging
@@ -159,39 +162,6 @@ private[spark] object ClosureCleaner extends Logging {
159162
clean(closure, checkSerializable, cleanTransitively, Map.empty)
160163
}
161164

162-
/**
 * Attempt to obtain the `SerializedLambda` serialization proxy of the given closure.
 *
 * Only objects whose class is synthetic and directly implements `scala.Serializable` are
 * inspected; anything else yields `None`. Any exception raised while inspecting a candidate
 * is logged at debug level and also results in `None`.
 *
 * @param closure the closure to check.
 * @return the serialization proxy wrapped in `Some`, or `None` when the closure is not a
 *         candidate or cannot be inspected.
 */
private def getSerializedLambda(closure: AnyRef): Option[SerializedLambda] = {
  val closureClass = closure.getClass

  // Evaluated lazily so the interface scan only happens for synthetic classes.
  def implementsScalaSerializable: Boolean =
    closureClass.getInterfaces.exists(iface => iface.getName == "scala.Serializable")

  if (!closureClass.isSynthetic || !implementsScalaSerializable) {
    None
  } else {
    try {
      Option(inspect(closure))
    } catch {
      case e: Exception =>
        // logDebug takes care of checking whether debug logging is enabled.
        logDebug("Closure is not a serialized lambda.", e)
        None
    }
  }
}
188-
189-
// Reflectively invokes the closure class's own `writeReplace` method and returns the result
// cast to SerializedLambda. Throws if the method is not declared on the closure's class or
// if the returned object is not a SerializedLambda.
private def inspect(closure: AnyRef): SerializedLambda = {
  val replaceMethod = closure.getClass.getDeclaredMethod("writeReplace")
  replaceMethod.setAccessible(true)
  val proxy = replaceMethod.invoke(closure)
  proxy.asInstanceOf[SerializedLambda]
}
194-
195165
/**
196166
* Helper method to clean the given closure in place.
197167
*
@@ -239,12 +209,12 @@ private[spark] object ClosureCleaner extends Logging {
239209
cleanTransitively: Boolean,
240210
accessedFields: Map[Class[_], Set[String]]): Unit = {
241211

242-
// most likely to be the case with 2.12, 2.13
212+
// indylambda check. Most likely to be the case with 2.12, 2.13
243213
// so we check first
244214
// non LMF-closures should be less frequent from now on
245-
val lambdaFunc = getSerializedLambda(func)
215+
val maybeIndylambdaProxy = IndylambdaScalaClosures.getSerializationProxy(func)
246216

247-
if (!isClosure(func.getClass) && lambdaFunc.isEmpty) {
217+
if (!isClosure(func.getClass) && maybeIndylambdaProxy.isEmpty) {
248218
logDebug(s"Expected a closure; got ${func.getClass.getName}")
249219
return
250220
}
@@ -256,7 +226,7 @@ private[spark] object ClosureCleaner extends Logging {
256226
return
257227
}
258228

259-
if (lambdaFunc.isEmpty) {
229+
if (maybeIndylambdaProxy.isEmpty) {
260230
logDebug(s"+++ Cleaning closure $func (${func.getClass.getName}) +++")
261231

262232
// A list of classes that represents closures enclosed in the given one
@@ -372,14 +342,60 @@ private[spark] object ClosureCleaner extends Logging {
372342

373343
logDebug(s" +++ closure $func (${func.getClass.getName}) is now cleaned +++")
374344
} else {
375-
logDebug(s"Cleaning lambda: ${lambdaFunc.get.getImplMethodName}")
345+
val lambdaProxy = maybeIndylambdaProxy.get
346+
val implMethodName = lambdaProxy.getImplMethodName
347+
348+
logDebug(s"Cleaning indylambda closure: $implMethodName")
349+
350+
// capturing class is the class that declared this lambda
351+
val capturingClassName = lambdaProxy.getCapturingClass.replace('/', '.')
352+
val classLoader = func.getClass.getClassLoader // this is the safest option
353+
// scalastyle:off classforname
354+
val capturingClass = Class.forName(capturingClassName, false, classLoader)
355+
// scalastyle:on classforname
376356

377-
val captClass = Utils.classForName(lambdaFunc.get.getCapturingClass.replace('/', '.'),
378-
initialize = false, noSparkClassLoader = true)
379357
// Fail fast if we detect return statements in closures
380-
getClassReader(captClass)
381-
.accept(new ReturnStatementFinder(Some(lambdaFunc.get.getImplMethodName)), 0)
382-
logDebug(s" +++ Lambda closure (${lambdaFunc.get.getImplMethodName}) is now cleaned +++")
358+
val capturingClassReader = getClassReader(capturingClass)
359+
capturingClassReader.accept(new ReturnStatementFinder(Option(implMethodName)), 0)
360+
361+
val isClosureDeclaredInScalaRepl = capturingClassName.startsWith("$line") &&
362+
capturingClassName.endsWith("$iw")
363+
val outerThisOpt = if (lambdaProxy.getCapturedArgCount > 0) {
364+
Option(lambdaProxy.getCapturedArg(0))
365+
} else {
366+
None
367+
}
368+
369+
// only need to clean when there is an enclosing "this" captured by the closure, and it
370+
// should be something cleanable, i.e. a Scala REPL line object
371+
val needsCleaning = isClosureDeclaredInScalaRepl &&
372+
outerThisOpt.isDefined && outerThisOpt.get.getClass.getName == capturingClassName
373+
374+
if (needsCleaning) {
375+
assert(accessedFields.isEmpty)
376+
377+
initAccessedFields(accessedFields, Seq(capturingClass))
378+
IndylambdaScalaClosures.findAccessedFields(lambdaProxy, classLoader, accessedFields)
379+
380+
logDebug(s" + fields accessed by starting closure: " + accessedFields.size)
381+
accessedFields.foreach { f => logDebug(" " + f) }
382+
383+
if (accessedFields(capturingClass).size < capturingClass.getDeclaredFields.length) {
384+
// clone and clean the enclosing `this` only when there are fields to null out
385+
386+
val outerThis = outerThisOpt.get
387+
388+
logDebug(s" + cloning instance of REPL class $capturingClassName")
389+
val clonedOuterThis = cloneAndSetFields(
390+
parent = null, outerThis, capturingClass, accessedFields)
391+
392+
val outerField = func.getClass.getDeclaredField("arg$1")
393+
outerField.setAccessible(true)
394+
outerField.set(func, clonedOuterThis)
395+
}
396+
}
397+
398+
logDebug(s" +++ indylambda closure ($implMethodName) is now cleaned +++")
383399
}
384400

385401
if (checkSerializable) {
@@ -414,6 +430,139 @@ private[spark] object ClosureCleaner extends Logging {
414430
}
415431
}
416432

433+
private[spark] object IndylambdaScalaClosures extends Logging {
  // Internal (slash-separated) name of java.lang.invoke.LambdaMetafactory.
  val LambdaMetafactoryClassName = "java/lang/invoke/LambdaMetafactory"
  // Name of the bootstrap method that Scala indylambda closures use.
  val LambdaMetafactoryMethodName = "altMetafactory"
  // JVM descriptor of that bootstrap method.
  val LambdaMetafactoryMethodDesc = "(Ljava/lang/invoke/MethodHandles$Lookup;" +
    "Ljava/lang/String;Ljava/lang/invoke/MethodType;[Ljava/lang/Object;)" +
    "Ljava/lang/invoke/CallSite;"

  /**
   * Check if the given reference is an indylambda style Scala closure.
   * If so, return a non-empty serialization proxy (SerializedLambda) of the closure;
   * otherwise return None.
   *
   * @param maybeClosure the closure to check.
   * @return `Some(proxy)` when the object looks like an indylambda Scala closure and its
   *         serialization proxy could be obtained; `None` otherwise (including when
   *         inspecting the object throws, which is logged at debug level).
   */
  def getSerializationProxy(maybeClosure: AnyRef): Option[SerializedLambda] = {
    val maybeClosureClass = maybeClosure.getClass

    // Fast check first: indylambda closure classes are generated by Java's
    // LambdaMetafactory, and they're always synthetic.
    if (!maybeClosureClass.isSynthetic) {
      None
    } else {
      val implementedInterfaces = ClassUtils.getAllInterfaces(maybeClosureClass).asScala
      val isClosureCandidate =
        implementedInterfaces.exists(_.getName == "scala.Serializable") &&
          implementedInterfaces.exists(_.getName.startsWith("scala.Function"))

      if (isClosureCandidate) {
        try {
          val lambdaProxy = inspect(maybeClosure)
          if (isIndylambdaScalaClosure(lambdaProxy)) Option(lambdaProxy) else None
        } catch {
          case e: Exception =>
            // no need to check if debug is enabled here the Spark logging api covers this.
            logDebug("The given reference is not an indylambda Scala closure.", e)
            None
        }
      } else {
        None
      }
    }
  }

  /**
   * Tell whether a serialization proxy describes an indylambda Scala closure: its
   * implementation method must be invoked statically and have `$anonfun$` in its name.
   *
   * @param lambdaProxy the serialization proxy to examine.
   */
  def isIndylambdaScalaClosure(lambdaProxy: SerializedLambda): Boolean = {
    lambdaProxy.getImplMethodKind == MethodHandleInfo.REF_invokeStatic &&
      lambdaProxy.getImplMethodName.contains("$anonfun$")
    // && implements a scala.runtime.java8 functional interface
  }

  /**
   * Reflectively invoke the `writeReplace` method declared on the closure's class and return
   * its result as a SerializedLambda.
   *
   * Throws if the class declares no `writeReplace` method, if the invocation fails, or if the
   * result is not a SerializedLambda; `getSerializationProxy` catches these.
   *
   * @param closure the object whose serialization proxy is requested.
   */
  def inspect(closure: AnyRef): SerializedLambda = {
    val writeReplace = closure.getClass.getDeclaredMethod("writeReplace")
    writeReplace.setAccessible(true)
    writeReplace.invoke(closure).asInstanceOf[SerializedLambda]
  }

  /**
   * Scan the bytecode of the closure's implementation method and record which fields it
   * reads or writes.
   *
   * Starting from the implementation method named by `lambdaProxy`, the scan follows
   *  - calls to other methods declared in the same implementation class, and
   *  - lexically nested Scala closures (created through `LambdaMetafactory.altMetafactory`)
   *    whose implementation method lives in the same class and takes that class as its
   *    first parameter, i.e. closures that also capture the enclosing `this`.
   *
   * Every GETFIELD/PUTFIELD whose owner is one of the classes already present as a key of
   * `accessedFields` adds the field name to that class's set. Calls into other classes are
   * ignored. Each method is scanned at most once, so recursive or mutually recursive methods
   * do not make the scan loop forever.
   *
   * @param lambdaProxy serialization proxy of the closure to analyze.
   * @param lambdaClassLoader class loader used to load the implementation class.
   * @param accessedFields pre-populated map from class to accessed field names; updated
   *                       in place.
   * @throws NoSuchElementException if the implementation method named by the proxy is not
   *                                declared in the implementation class.
   */
  def findAccessedFields(
      lambdaProxy: SerializedLambda,
      lambdaClassLoader: ClassLoader,
      accessedFields: Map[Class[_], Set[String]]): Unit = {
    val implClassInternalName = lambdaProxy.getImplClass
    // scalastyle:off classforname
    val implClass = Class.forName(
      implClassInternalName.replace('/', '.'), false, lambdaClassLoader)
    // scalastyle:on classforname
    val implClassNode = new ClassNode()
    val implClassReader = ClosureCleaner.getClassReader(implClass)
    implClassReader.accept(implClassNode, 0)

    // Index every method declared in the implementation class by (class, name, descriptor).
    val methodsByName = Map.empty[MethodIdentifier[_], MethodNode]
    for (m <- implClassNode.methods.asScala) {
      methodsByName(MethodIdentifier(implClass, m.name, m.desc)) = m
    }

    val implMethodId = MethodIdentifier(
      implClass, lambdaProxy.getImplMethodName, lambdaProxy.getImplMethodSignature)

    // Methods that have already been scheduled for scanning. Without this check a method
    // that (directly or indirectly) calls itself would be pushed onto the stack forever.
    val visited = Set[MethodIdentifier[_]](implMethodId)
    val stack = Stack[MethodIdentifier[_]](implMethodId)

    // Schedule a method for scanning unless it was scheduled before or has no body in the
    // implementation class (e.g. an inherited method referenced through this class).
    def pushIfNotVisited(methodId: MethodIdentifier[_]): Unit = {
      if (!methodsByName.contains(methodId)) {
        logTrace(s" skipping $methodId: not declared in $implClassInternalName")
      } else if (visited.add(methodId)) {
        stack.push(methodId)
      }
    }

    while (stack.nonEmpty) {
      val currentId = stack.pop()
      // For the entry method this lookup throws NoSuchElementException when the method is
      // missing; every other id was checked by pushIfNotVisited before being pushed.
      val currentMethodNode = methodsByName(currentId)
      logTrace(s" scanning $currentId")
      currentMethodNode.accept(new MethodVisitor(ASM7) {
        override def visitFieldInsn(op: Int, owner: String, name: String, desc: String): Unit = {
          if (op == GETFIELD || op == PUTFIELD) {
            val ownerExternalName = owner.replace('/', '.')
            for (cl <- accessedFields.keys if cl.getName == ownerExternalName) {
              logTrace(s" found field access $name on $owner")
              accessedFields(cl) += name
            }
          }
        }

        override def visitMethodInsn(
            op: Int, owner: String, name: String, desc: String, itf: Boolean): Unit = {
          if (owner == implClassInternalName) {
            logTrace(s" found intra class call to $owner.$name$desc")
            pushIfNotVisited(MethodIdentifier(implClass, name, desc))
          } else {
            // keep the same behavior as the original ClosureCleaner
            logTrace(s" ignoring call to $owner.$name$desc")
          }
        }

        // find the lexically nested closures
        override def visitInvokeDynamicInsn(
            name: String, desc: String, bsmHandle: Handle, bsmArgs: Object*): Unit = {
          logTrace(s" invokedynamic: $name$desc, bsmHandle=$bsmHandle, bsmArgs=$bsmArgs")

          // fast check: we only care about Scala lambda creation
          val createsScalaFunction = name.startsWith("apply") &&
            Type.getReturnType(desc).getDescriptor.startsWith("Lscala/Function")

          if (createsScalaFunction &&
              bsmHandle.getOwner == LambdaMetafactoryClassName &&
              bsmHandle.getName == LambdaMetafactoryMethodName &&
              bsmHandle.getDesc == LambdaMetafactoryMethodDesc) {
            // OK we're in the right bootstrap method for serializable Java 8 style lambda
            // creation; the second bootstrap argument is the implementation method handle.
            val targetHandle = bsmArgs(1).asInstanceOf[Handle]
            if (targetHandle.getOwner == implClassInternalName &&
                targetHandle.getDesc.startsWith(s"(L$implClassInternalName;")) {
              // this is a lexically nested closure that also captures the enclosing `this`
              logDebug(s" found inner closure $targetHandle")
              pushIfNotVisited(
                MethodIdentifier(implClass, targetHandle.getName, targetHandle.getDesc))
            }
          }
        }
      })
    }
  }
}
565+
417566
private[spark] class ReturnStatementInClosureException
418567
extends SparkException("Return statements aren't allowed in Spark closures")
419568

0 commit comments

Comments
 (0)