18
18
package org .apache .spark .util
19
19
20
20
import java .io .{ByteArrayInputStream , ByteArrayOutputStream }
21
- import java .lang .invoke .SerializedLambda
21
+ import java .lang .invoke .{ MethodHandleInfo , SerializedLambda }
22
22
23
+ import scala .collection .JavaConverters ._
23
24
import scala .collection .mutable .{Map , Set , Stack }
24
25
25
- import org .apache .xbean .asm7 .{ClassReader , ClassVisitor , MethodVisitor , Type }
26
+ import org .apache .commons .lang3 .ClassUtils
27
+ import org .apache .xbean .asm7 .{ClassReader , ClassVisitor , Handle , MethodVisitor , Type }
26
28
import org .apache .xbean .asm7 .Opcodes ._
29
+ import org .apache .xbean .asm7 .tree .{ClassNode , MethodNode }
27
30
28
31
import org .apache .spark .{SparkEnv , SparkException }
29
32
import org .apache .spark .internal .Logging
@@ -159,39 +162,6 @@ private[spark] object ClosureCleaner extends Logging {
159
162
clean(closure, checkSerializable, cleanTransitively, Map .empty)
160
163
}
161
164
162
- /**
163
- * Try to get a serialized Lambda from the closure.
164
- *
165
- * @param closure the closure to check.
166
- */
167
- private def getSerializedLambda (closure : AnyRef ): Option [SerializedLambda ] = {
168
- val isClosureCandidate =
169
- closure.getClass.isSynthetic &&
170
- closure
171
- .getClass
172
- .getInterfaces.exists(_.getName == " scala.Serializable" )
173
-
174
- if (isClosureCandidate) {
175
- try {
176
- Option (inspect(closure))
177
- } catch {
178
- case e : Exception =>
179
- // no need to check if debug is enabled here the Spark
180
- // logging api covers this.
181
- logDebug(" Closure is not a serialized lambda." , e)
182
- None
183
- }
184
- } else {
185
- None
186
- }
187
- }
188
-
189
- private def inspect (closure : AnyRef ): SerializedLambda = {
190
- val writeReplace = closure.getClass.getDeclaredMethod(" writeReplace" )
191
- writeReplace.setAccessible(true )
192
- writeReplace.invoke(closure).asInstanceOf [java.lang.invoke.SerializedLambda ]
193
- }
194
-
195
165
/**
196
166
* Helper method to clean the given closure in place.
197
167
*
@@ -239,12 +209,12 @@ private[spark] object ClosureCleaner extends Logging {
239
209
cleanTransitively : Boolean ,
240
210
accessedFields : Map [Class [_], Set [String ]]): Unit = {
241
211
242
- // most likely to be the case with 2.12, 2.13
212
+ // indylambda check. Most likely to be the case with 2.12, 2.13
243
213
// so we check first
244
214
// non LMF-closures should be less frequent from now on
245
- val lambdaFunc = getSerializedLambda (func)
215
+ val maybeIndylambdaProxy = IndylambdaScalaClosures .getSerializationProxy (func)
246
216
247
- if (! isClosure(func.getClass) && lambdaFunc .isEmpty) {
217
+ if (! isClosure(func.getClass) && maybeIndylambdaProxy .isEmpty) {
248
218
logDebug(s " Expected a closure; got ${func.getClass.getName}" )
249
219
return
250
220
}
@@ -256,7 +226,7 @@ private[spark] object ClosureCleaner extends Logging {
256
226
return
257
227
}
258
228
259
- if (lambdaFunc .isEmpty) {
229
+ if (maybeIndylambdaProxy .isEmpty) {
260
230
logDebug(s " +++ Cleaning closure $func ( ${func.getClass.getName}) +++ " )
261
231
262
232
// A list of classes that represents closures enclosed in the given one
@@ -372,14 +342,60 @@ private[spark] object ClosureCleaner extends Logging {
372
342
373
343
logDebug(s " +++ closure $func ( ${func.getClass.getName}) is now cleaned +++ " )
374
344
} else {
375
- logDebug(s " Cleaning lambda: ${lambdaFunc.get.getImplMethodName}" )
345
+ val lambdaProxy = maybeIndylambdaProxy.get
346
+ val implMethodName = lambdaProxy.getImplMethodName
347
+
348
+ logDebug(s " Cleaning indylambda closure: $implMethodName" )
349
+
350
+ // capturing class is the class that declared this lambda
351
+ val capturingClassName = lambdaProxy.getCapturingClass.replace('/' , '.' )
352
+ val classLoader = func.getClass.getClassLoader // this is the safest option
353
+ // scalastyle:off classforname
354
+ val capturingClass = Class .forName(capturingClassName, false , classLoader)
355
+ // scalastyle:on classforname
376
356
377
- val captClass = Utils .classForName(lambdaFunc.get.getCapturingClass.replace('/' , '.' ),
378
- initialize = false , noSparkClassLoader = true )
379
357
// Fail fast if we detect return statements in closures
380
- getClassReader(captClass)
381
- .accept(new ReturnStatementFinder (Some (lambdaFunc.get.getImplMethodName)), 0 )
382
- logDebug(s " +++ Lambda closure ( ${lambdaFunc.get.getImplMethodName}) is now cleaned +++ " )
358
+ val capturingClassReader = getClassReader(capturingClass)
359
+ capturingClassReader.accept(new ReturnStatementFinder (Option (implMethodName)), 0 )
360
+
361
+ val isClosureDeclaredInScalaRepl = capturingClassName.startsWith(" $line" ) &&
362
+ capturingClassName.endsWith(" $iw" )
363
+ val outerThisOpt = if (lambdaProxy.getCapturedArgCount > 0 ) {
364
+ Option (lambdaProxy.getCapturedArg(0 ))
365
+ } else {
366
+ None
367
+ }
368
+
369
+ // only need to clean when there is an enclosing "this" captured by the closure, and it
370
+ // should be something cleanable, i.e. a Scala REPL line object
371
+ val needsCleaning = isClosureDeclaredInScalaRepl &&
372
+ outerThisOpt.isDefined && outerThisOpt.get.getClass.getName == capturingClassName
373
+
374
+ if (needsCleaning) {
375
+ assert(accessedFields.isEmpty)
376
+
377
+ initAccessedFields(accessedFields, Seq (capturingClass))
378
+ IndylambdaScalaClosures .findAccessedFields(lambdaProxy, classLoader, accessedFields)
379
+
380
+ logDebug(s " + fields accessed by starting closure: " + accessedFields.size)
381
+ accessedFields.foreach { f => logDebug(" " + f) }
382
+
383
+ if (accessedFields(capturingClass).size < capturingClass.getDeclaredFields.length) {
384
+ // clone and clean the enclosing `this` only when there are fields to null out
385
+
386
+ val outerThis = outerThisOpt.get
387
+
388
+ logDebug(s " + cloning instance of REPL class $capturingClassName" )
389
+ val clonedOuterThis = cloneAndSetFields(
390
+ parent = null , outerThis, capturingClass, accessedFields)
391
+
392
+ val outerField = func.getClass.getDeclaredField(" arg$1" )
393
+ outerField.setAccessible(true )
394
+ outerField.set(func, clonedOuterThis)
395
+ }
396
+ }
397
+
398
+ logDebug(s " +++ indylambda closure ( $implMethodName) is now cleaned +++ " )
383
399
}
384
400
385
401
if (checkSerializable) {
@@ -414,6 +430,139 @@ private[spark] object ClosureCleaner extends Logging {
414
430
}
415
431
}
416
432
433
+ private [spark] object IndylambdaScalaClosures extends Logging {
434
+ // internal name of java.lang.invoke.LambdaMetafactory
435
+ val LambdaMetafactoryClassName = " java/lang/invoke/LambdaMetafactory"
436
+ // the method that Scala indylambda use for bootstrap method
437
+ val LambdaMetafactoryMethodName = " altMetafactory"
438
+ val LambdaMetafactoryMethodDesc = " (Ljava/lang/invoke/MethodHandles$Lookup;" +
439
+ " Ljava/lang/String;Ljava/lang/invoke/MethodType;[Ljava/lang/Object;)" +
440
+ " Ljava/lang/invoke/CallSite;"
441
+
442
+ /**
443
+ * Check if the given reference is a indylambda style Scala closure.
444
+ * If so, return a non-empty serialization proxy (SerializedLambda) of the closure;
445
+ * otherwise return None.
446
+ *
447
+ * @param maybeClosure the closure to check.
448
+ */
449
+ def getSerializationProxy (maybeClosure : AnyRef ): Option [SerializedLambda ] = {
450
+ val maybeClosureClass = maybeClosure.getClass
451
+
452
+ // shortcut the fast check:
453
+ // indylambda closure classes are generated by Java's LambdaMetafactory, and they're always
454
+ // synthetic.
455
+ if (! maybeClosureClass.isSynthetic) return None
456
+
457
+ val implementedInterfaces = ClassUtils .getAllInterfaces(maybeClosureClass).asScala
458
+ val isClosureCandidate = implementedInterfaces.exists(_.getName == " scala.Serializable" ) &&
459
+ implementedInterfaces.exists(_.getName.startsWith(" scala.Function" ))
460
+
461
+ if (isClosureCandidate) {
462
+ try {
463
+ val lambdaProxy = inspect(maybeClosure)
464
+ if (isIndylambdaScalaClosure(lambdaProxy)) Option (lambdaProxy)
465
+ else None
466
+ } catch {
467
+ case e : Exception =>
468
+ // no need to check if debug is enabled here the Spark logging api covers this.
469
+ logDebug(" The given reference is not an indylambda Scala closure." , e)
470
+ None
471
+ }
472
+ } else {
473
+ None
474
+ }
475
+ }
476
+
477
+ def isIndylambdaScalaClosure (lambdaProxy : SerializedLambda ): Boolean = {
478
+ lambdaProxy.getImplMethodKind == MethodHandleInfo .REF_invokeStatic &&
479
+ lambdaProxy.getImplMethodName.contains(" $anonfun$" )
480
+ // && implements a scala.runtime.java8 functional interface
481
+ }
482
+
483
+ def inspect (closure : AnyRef ): SerializedLambda = {
484
+ val writeReplace = closure.getClass.getDeclaredMethod(" writeReplace" )
485
+ writeReplace.setAccessible(true )
486
+ writeReplace.invoke(closure).asInstanceOf [SerializedLambda ]
487
+ }
488
+
489
+ def findAccessedFields (
490
+ lambdaProxy : SerializedLambda ,
491
+ lambdaClassLoader : ClassLoader ,
492
+ accessedFields : Map [Class [_], Set [String ]]): Unit = {
493
+ val implClassInternalName = lambdaProxy.getImplClass
494
+ // scalastyle:off classforname
495
+ val implClass = Class .forName(
496
+ implClassInternalName.replace('/' , '.' ), false , lambdaClassLoader)
497
+ // scalastyle:on classforname
498
+ val implClassNode = new ClassNode ()
499
+ val implClassReader = ClosureCleaner .getClassReader(implClass)
500
+ implClassReader.accept(implClassNode, 0 )
501
+
502
+ val methodsByName = Map .empty[MethodIdentifier [_], MethodNode ]
503
+ for (m <- implClassNode.methods.asScala) {
504
+ methodsByName(MethodIdentifier (implClass, m.name, m.desc)) = m
505
+ }
506
+
507
+ val implMethodId = MethodIdentifier (
508
+ implClass, lambdaProxy.getImplMethodName, lambdaProxy.getImplMethodSignature)
509
+ val implMethodNode = methodsByName(implMethodId)
510
+
511
+ val visited = Set [MethodIdentifier [_]](implMethodId)
512
+ val stack = Stack [MethodIdentifier [_]](implMethodId)
513
+ while (! stack.isEmpty) {
514
+ val currentId = stack.pop
515
+ val currentMethodNode = methodsByName(currentId)
516
+ logTrace(s " scanning $currentId" )
517
+ currentMethodNode.accept(new MethodVisitor (ASM7 ) {
518
+ override def visitFieldInsn (op : Int , owner : String , name : String , desc : String ): Unit = {
519
+ if (op == GETFIELD || op == PUTFIELD ) {
520
+ val ownerExternalName = owner.replace('/' , '.' )
521
+ for (cl <- accessedFields.keys if cl.getName == ownerExternalName) {
522
+ logTrace(s " found field access $name on $owner" )
523
+ accessedFields(cl) += name
524
+ }
525
+ }
526
+ }
527
+
528
+ override def visitMethodInsn (
529
+ op : Int , owner : String , name : String , desc : String , itf : Boolean ): Unit = {
530
+ if (owner == implClassInternalName) {
531
+ logTrace(s " found intra class call to $owner. $name$desc" )
532
+ stack.push(MethodIdentifier (implClass, name, desc))
533
+ } else {
534
+ // keep the same behavior as the original ClosureCleaner
535
+ logTrace(s " ignoring call to $owner. $name$desc" )
536
+ }
537
+ }
538
+
539
+ // find the lexically nested closures
540
+ override def visitInvokeDynamicInsn (
541
+ name : String , desc : String , bsmHandle : Handle , bsmArgs : Object * ): Unit = {
542
+ logTrace(s " invokedynamic: $name$desc, bsmHandle= $bsmHandle, bsmArgs= $bsmArgs" )
543
+
544
+ // fast check: we only care about Scala lambda creation
545
+ if (! name.startsWith(" apply" )) return
546
+ if (! Type .getReturnType(desc).getDescriptor.startsWith(" Lscala/Function" )) return
547
+
548
+ if (bsmHandle.getOwner == LambdaMetafactoryClassName &&
549
+ bsmHandle.getName == LambdaMetafactoryMethodName &&
550
+ bsmHandle.getDesc == LambdaMetafactoryMethodDesc ) {
551
+ // OK we're in the right bootstrap method for serializable Java 8 style lambda creation
552
+ val targetHandle = bsmArgs(1 ).asInstanceOf [Handle ]
553
+ if (targetHandle.getOwner == implClassInternalName &&
554
+ targetHandle.getDesc.startsWith(s " (L $implClassInternalName; " )) {
555
+ // this is a lexically nested closure that also captures the enclosing `this`
556
+ logDebug(s " found inner closure $targetHandle" )
557
+ stack.push(MethodIdentifier (implClass, targetHandle.getName, targetHandle.getDesc))
558
+ }
559
+ }
560
+ }
561
+ })
562
+ }
563
+ }
564
+ }
565
+
417
566
private [spark] class ReturnStatementInClosureException
418
567
extends SparkException (" Return statements aren't allowed in Spark closures" )
419
568
0 commit comments