[SPARK-11195][CORE] Use correct classloader for TaskResultGetter #9779

Closed · wants to merge 2 commits
11 changes: 6 additions & 5 deletions core/src/main/scala/org/apache/spark/TestUtils.scala
@@ -20,6 +20,7 @@ package org.apache.spark
import java.io.{ByteArrayInputStream, File, FileInputStream, FileOutputStream}
import java.net.{URI, URL}
import java.nio.charset.StandardCharsets
import java.nio.file.Paths
import java.util.Arrays
import java.util.jar.{JarEntry, JarOutputStream}

@@ -83,15 +84,15 @@ private[spark] object TestUtils {
}

/**
* Create a jar file that contains this set of files. All files will be located at the root
* of the jar.
* Create a jar file that contains this set of files. All files will be located in the specified
* directory or at the root of the jar.
*/
def createJar(files: Seq[File], jarFile: File): URL = {
def createJar(files: Seq[File], jarFile: File, directoryPrefix: Option[String] = None): URL = {
val jarFileStream = new FileOutputStream(jarFile)
val jarStream = new JarOutputStream(jarFileStream, new java.util.jar.Manifest())

for (file <- files) {
val jarEntry = new JarEntry(file.getName)
val jarEntry = new JarEntry(Paths.get(directoryPrefix.getOrElse(""), file.getName).toString)
jarStream.putNextEntry(jarEntry)

val in = new FileInputStream(file)
@@ -123,7 +124,7 @@
classpathUrls: Seq[URL]): File = {
val compiler = ToolProvider.getSystemJavaCompiler

// Calling this outputs a class file in pwd. It's easier to just rename the file than
// Calling this outputs a class file in pwd. It's easier to just rename the files than
// build a custom FileManager that controls the output location.
val options = if (classpathUrls.nonEmpty) {
Seq("-classpath", classpathUrls.map { _.getFile }.mkString(File.pathSeparator))
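
For illustration, a minimal sketch of how the new directoryPrefix parameter of TestUtils.createJar might be called; the package name and file names here are placeholders, and the suite change below exercises the same call with a real compiled class. The sketch sits in a package under org.apache.spark only because TestUtils is private[spark].

```scala
package org.apache.spark.reprotest  // hypothetical package, needed because TestUtils is private[spark]

import java.io.File

import org.apache.spark.TestUtils

object CreateJarSketch {
  def main(args: Array[String]): Unit = {
    // Throwaway files standing in for a compiled class file and the target jar.
    val classFile = File.createTempFile("MyException", ".class")
    val jarFile = File.createTempFile("testJar", ".jar")

    // With directoryPrefix, the entry lands at "repro/<name>" inside the jar
    // rather than at the jar root, matching the class's package layout.
    TestUtils.createJar(Seq(classFile), jarFile, directoryPrefix = Some("repro"))
    println(s"wrote ${jarFile.getAbsolutePath}")
  }
}
```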
@@ -103,16 +103,16 @@ private[spark] class TaskResultGetter(sparkEnv: SparkEnv, scheduler: TaskSchedul
try {
getTaskResultExecutor.execute(new Runnable {
override def run(): Unit = Utils.logUncaughtExceptions {
val loader = Utils.getContextOrSparkClassLoader
try {
if (serializedData != null && serializedData.limit() > 0) {
reason = serializer.get().deserialize[TaskEndReason](
serializedData, Utils.getSparkClassLoader)
serializedData, loader)
}
} catch {
case cnd: ClassNotFoundException =>
// Log an error but keep going here -- the task failed, so not catastrophic
// if we can't deserialize the reason.
val loader = Utils.getContextOrSparkClassLoader
logError(
"Could not deserialize TaskEndReason: ClassNotFound with classloader " + loader)
case ex: Exception => {}
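
The gist of the fix: when deserializing the failed-task reason, prefer the thread's context classloader (which can see user-added jars) and fall back to Spark's own classloader. A minimal, self-contained sketch of the selection logic, assuming Utils.getContextOrSparkClassLoader behaves roughly as follows:

```scala
object ClassLoaderSelectionSketch {
  // The classloader that loaded Spark's own classes.
  def sparkClassLoader: ClassLoader = getClass.getClassLoader

  // Prefer the thread's context classloader (for example a MutableURLClassLoader
  // that knows about user jars); fall back to the Spark classloader if none is set.
  def contextOrSparkClassLoader: ClassLoader =
    Option(Thread.currentThread().getContextClassLoader).getOrElse(sparkClassLoader)

  def main(args: Array[String]): Unit = {
    println(s"resolved classloader: $contextOrSparkClassLoader")
  }
}
```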
@@ -17,6 +17,8 @@

package org.apache.spark.scheduler

import java.io.File
import java.net.URL
import java.nio.ByteBuffer

import scala.concurrent.duration._
@@ -26,8 +28,10 @@ import scala.util.control.NonFatal
import org.scalatest.BeforeAndAfter
import org.scalatest.concurrent.Eventually._

import org.apache.spark.{LocalSparkContext, SparkConf, SparkContext, SparkEnv, SparkFunSuite}
import org.apache.spark._
import org.apache.spark.storage.TaskResultBlockId
import org.apache.spark.TestUtils.JavaSourceFromString
import org.apache.spark.util.{MutableURLClassLoader, Utils}

/**
* Removes the TaskResult from the BlockManager before delegating to a normal TaskResultGetter.
@@ -119,5 +123,64 @@ class TaskResultGetterSuite extends SparkFunSuite with BeforeAndAfter with Local
// Make sure two tasks were run (one failed one, and a second retried one).
assert(scheduler.nextTaskId.get() === 2)
}

/**
* Make sure we are using the context classloader when deserializing failed TaskResults instead
* of the Spark classloader.
*
* This test compiles a jar containing an exception and tests that when it is thrown on the
* executor, enqueueFailedTask can correctly deserialize the failure and identify the thrown
* exception as the cause.
*
* Before this fix, enqueueFailedTask would throw a ClassNotFoundException when deserializing
* the exception, resulting in UnknownReason being reported as the TaskEndReason.
*/
test("failed task deserialized with the correct classloader (SPARK-11195)") {
// compile a small jar containing an exception that will be thrown on an executor.
val tempDir = Utils.createTempDir()
val srcDir = new File(tempDir, "repro/")
srcDir.mkdirs()
val excSource = new JavaSourceFromString(new File(srcDir, "MyException").getAbsolutePath,
"""package repro;
|
|public class MyException extends Exception {
|}
""".stripMargin)
val excFile = TestUtils.createCompiledClass("MyException", srcDir, excSource, Seq.empty)
val jarFile = new File(tempDir, "testJar-%s.jar".format(System.currentTimeMillis()))
TestUtils.createJar(Seq(excFile), jarFile, directoryPrefix = Some("repro"))

// ensure we reset the classloader after the test completes
val originalClassLoader = Thread.currentThread.getContextClassLoader
try {
// load the exception from the jar
val loader = new MutableURLClassLoader(new Array[URL](0), originalClassLoader)
loader.addURL(jarFile.toURI.toURL)
Thread.currentThread().setContextClassLoader(loader)
val excClass: Class[_] = Utils.classForName("repro.MyException")

// NOTE: we must run the cluster with "local" so that the executor can load the compiled
// jar.
sc = new SparkContext("local", "test", conf)
val rdd = sc.parallelize(Seq(1), 1).map { _ =>
val exc = excClass.newInstance().asInstanceOf[Exception]
throw exc
}

// the driver should not have any problems resolving the exception class and determining
// why the task failed.
val exceptionMessage = intercept[SparkException] {
rdd.collect()
}.getMessage

val expectedFailure = """(?s).*Lost task.*: repro.MyException.*""".r
val unknownFailure = """(?s).*Lost task.*: UnknownReason.*""".r

assert(expectedFailure.findFirstMatchIn(exceptionMessage).isDefined)
assert(unknownFailure.findFirstMatchIn(exceptionMessage).isEmpty)
} finally {
Thread.currentThread.setContextClassLoader(originalClassLoader)
Contributor commented:
do we need to clean up the dirs and jars?

Author replied:
i don't think so - createTempDir adds a shutdown hook to delete the directory (see the sketch after the diff)

}
}
}
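
Regarding the cleanup question in the review thread above: a minimal sketch of the pattern the author is referring to, assuming Utils.createTempDir behaves roughly like this (the real implementation wires its hook through Spark's own shutdown handling rather than calling Runtime directly):

```scala
import java.io.File
import java.nio.file.Files

object TempDirCleanupSketch {
  // Create a temp directory and register a JVM shutdown hook that deletes it on exit,
  // which is why the test does not need to remove the directory or the jar explicitly.
  def createTempDirWithCleanup(prefix: String = "spark-test"): File = {
    val dir = Files.createTempDirectory(prefix).toFile
    Runtime.getRuntime.addShutdownHook(new Thread(new Runnable {
      override def run(): Unit = deleteRecursively(dir)
    }))
    dir
  }

  private def deleteRecursively(f: File): Unit = {
    Option(f.listFiles()).foreach(_.foreach(deleteRecursively))
    f.delete()
  }

  def main(args: Array[String]): Unit = {
    val dir = createTempDirWithCleanup()
    println(s"created ${dir.getAbsolutePath}; it is removed when the JVM exits")
  }
}
```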