
Commit f802b07

Hurshal Patel authored and yhuai committed
[SPARK-11195][CORE] Use correct classloader for TaskResultGetter
Make sure we are using the context classloader when deserializing failed TaskResults instead of the Spark classloader. The issue is that `enqueueFailedTask` was using the incorrect classloader, which resulted in a `ClassNotFoundException`.

Adds a test in TaskResultGetterSuite that compiles a custom exception, throws it on the executor, and asserts that Spark handles the TaskResult deserialization correctly instead of returning `UnknownReason`.

See #9367 for previous comments.
See SPARK-11195 for a full repro.

Author: Hurshal Patel <hpatel516@gmail.com>

Closes #9779 from choochootrain/spark-11195-master.

(cherry picked from commit 3cca5ff)
Signed-off-by: Yin Huai <yhuai@databricks.com>

Conflicts:
    core/src/main/scala/org/apache/spark/TestUtils.scala
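For readers unfamiliar with the pattern the fix relies on: Spark prefers the thread's context classloader, which can see user-added jars, and falls back to the loader that loaded Spark itself. A minimal sketch of that fallback (not Spark's actual Utils code, but the same idea):

// Sketch of the context-or-Spark classloader fallback. The context
// classloader may include user jars; the Spark classloader never does,
// so deserializing a user-defined exception with it fails.
object ClassLoaderFallback {
  def getContextOrSparkClassLoader: ClassLoader = {
    val sparkClassLoader = getClass.getClassLoader
    Option(Thread.currentThread().getContextClassLoader).getOrElse(sparkClassLoader)
  }
}

Deserializing a TaskEndReason that wraps a user-defined exception with only the Spark classloader throws ClassNotFoundException; that is exactly the failure the new test exercises.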
1 parent 4b8dc25 commit f802b07

File tree: 3 files changed, +72 additions, -8 deletions

core/src/main/scala/org/apache/spark/TestUtils.scala

Lines changed: 6 additions & 5 deletions

@@ -19,6 +19,7 @@ package org.apache.spark
 
 import java.io.{ByteArrayInputStream, File, FileInputStream, FileOutputStream}
 import java.net.{URI, URL}
+import java.nio.file.Paths
 import java.util.jar.{JarEntry, JarOutputStream}
 
 import scala.collection.JavaConversions._

@@ -78,15 +79,15 @@ private[spark] object TestUtils {
   }
 
   /**
-   * Create a jar file that contains this set of files. All files will be located at the root
-   * of the jar.
+   * Create a jar file that contains this set of files. All files will be located in the specified
+   * directory or at the root of the jar.
    */
-  def createJar(files: Seq[File], jarFile: File): URL = {
+  def createJar(files: Seq[File], jarFile: File, directoryPrefix: Option[String] = None): URL = {
     val jarFileStream = new FileOutputStream(jarFile)
     val jarStream = new JarOutputStream(jarFileStream, new java.util.jar.Manifest())
 
     for (file <- files) {
-      val jarEntry = new JarEntry(file.getName)
+      val jarEntry = new JarEntry(Paths.get(directoryPrefix.getOrElse(""), file.getName).toString)
       jarStream.putNextEntry(jarEntry)
 
       val in = new FileInputStream(file)

@@ -118,7 +119,7 @@
       classpathUrls: Seq[URL]): File = {
     val compiler = ToolProvider.getSystemJavaCompiler
 
-    // Calling this outputs a class file in pwd. It's easier to just rename the file than
+    // Calling this outputs a class file in pwd. It's easier to just rename the files than
     // build a custom FileManager that controls the output location.
     val options = if (classpathUrls.nonEmpty) {
       Seq("-classpath", classpathUrls.map { _.getFile }.mkString(File.pathSeparator))

core/src/main/scala/org/apache/spark/scheduler/TaskResultGetter.scala

Lines changed: 2 additions & 2 deletions

@@ -103,16 +103,16 @@ private[spark] class TaskResultGetter(sparkEnv: SparkEnv, scheduler: TaskSchedul
     try {
       getTaskResultExecutor.execute(new Runnable {
         override def run(): Unit = Utils.logUncaughtExceptions {
+          val loader = Utils.getContextOrSparkClassLoader
           try {
             if (serializedData != null && serializedData.limit() > 0) {
               reason = serializer.get().deserialize[TaskEndReason](
-                serializedData, Utils.getSparkClassLoader)
+                serializedData, loader)
             }
           } catch {
             case cnd: ClassNotFoundException =>
               // Log an error but keep going here -- the task failed, so not catastrophic
               // if we can't deserialize the reason.
-              val loader = Utils.getContextOrSparkClassLoader
               logError(
                 "Could not deserialize TaskEndReason: ClassNotFound with classloader " + loader)
             case ex: Exception => {}
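Why the loader argument matters comes down to how Java deserialization resolves classes. A minimal sketch (not Spark's serializer; Spark's real implementation lives in JavaSerializer) of an ObjectInputStream that resolves classes through a supplied loader:

import java.io.{ByteArrayInputStream, ObjectInputStream, ObjectStreamClass}

// Sketch only: resolve deserialized classes through the given loader. If the
// loader cannot see the class (e.g. the Spark classloader asked to resolve a
// user jar's exception), Class.forName throws ClassNotFoundException -- the
// exact failure enqueueFailedTask hit before this change.
class LoaderAwareObjectInputStream(bytes: Array[Byte], loader: ClassLoader)
  extends ObjectInputStream(new ByteArrayInputStream(bytes)) {
  override def resolveClass(desc: ObjectStreamClass): Class[_] =
    Class.forName(desc.getName, false, loader)
}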

core/src/test/scala/org/apache/spark/scheduler/TaskResultGetterSuite.scala

Lines changed: 64 additions & 1 deletion

@@ -17,6 +17,8 @@
 
 package org.apache.spark.scheduler
 
+import java.io.File
+import java.net.URL
 import java.nio.ByteBuffer
 
 import scala.concurrent.duration._

@@ -26,8 +28,10 @@ import scala.util.control.NonFatal
 import org.scalatest.BeforeAndAfter
 import org.scalatest.concurrent.Eventually._
 
-import org.apache.spark.{LocalSparkContext, SparkConf, SparkContext, SparkEnv, SparkFunSuite}
+import org.apache.spark._
 import org.apache.spark.storage.TaskResultBlockId
+import org.apache.spark.TestUtils.JavaSourceFromString
+import org.apache.spark.util.{MutableURLClassLoader, Utils}
 
 /**
  * Removes the TaskResult from the BlockManager before delegating to a normal TaskResultGetter.

@@ -119,5 +123,64 @@ class TaskResultGetterSuite extends SparkFunSuite with BeforeAndAfter with Local
     // Make sure two tasks were run (one failed one, and a second retried one).
     assert(scheduler.nextTaskId.get() === 2)
   }
+
+  /**
+   * Make sure we are using the context classloader when deserializing failed TaskResults instead
+   * of the Spark classloader.
+   *
+   * This test compiles a jar containing an exception and tests that when it is thrown on the
+   * executor, enqueueFailedTask can correctly deserialize the failure and identify the thrown
+   * exception as the cause.
+   *
+   * Before this fix, enqueueFailedTask would throw a ClassNotFoundException when deserializing
+   * the exception, resulting in an UnknownReason for the TaskEndResult.
+   */
+  test("failed task deserialized with the correct classloader (SPARK-11195)") {
+    // compile a small jar containing an exception that will be thrown on an executor.
+    val tempDir = Utils.createTempDir()
+    val srcDir = new File(tempDir, "repro/")
+    srcDir.mkdirs()
+    val excSource = new JavaSourceFromString(new File(srcDir, "MyException").getAbsolutePath,
+      """package repro;
+        |
+        |public class MyException extends Exception {
+        |}
+      """.stripMargin)
+    val excFile = TestUtils.createCompiledClass("MyException", srcDir, excSource, Seq.empty)
+    val jarFile = new File(tempDir, "testJar-%s.jar".format(System.currentTimeMillis()))
+    TestUtils.createJar(Seq(excFile), jarFile, directoryPrefix = Some("repro"))
+
+    // ensure we reset the classloader after the test completes
+    val originalClassLoader = Thread.currentThread.getContextClassLoader
+    try {
+      // load the exception from the jar
+      val loader = new MutableURLClassLoader(new Array[URL](0), originalClassLoader)
+      loader.addURL(jarFile.toURI.toURL)
+      Thread.currentThread().setContextClassLoader(loader)
+      val excClass: Class[_] = Utils.classForName("repro.MyException")
+
+      // NOTE: we must run the cluster with "local" so that the executor can load the compiled
+      // jar.
+      sc = new SparkContext("local", "test", conf)
+      val rdd = sc.parallelize(Seq(1), 1).map { _ =>
+        val exc = excClass.newInstance().asInstanceOf[Exception]
+        throw exc
+      }
+
+      // the driver should not have any problems resolving the exception class and determining
+      // why the task failed.
+      val exceptionMessage = intercept[SparkException] {
+        rdd.collect()
+      }.getMessage
+
+      val expectedFailure = """(?s).*Lost task.*: repro.MyException.*""".r
+      val unknownFailure = """(?s).*Lost task.*: UnknownReason.*""".r
+
+      assert(expectedFailure.findFirstMatchIn(exceptionMessage).isDefined)
+      assert(unknownFailure.findFirstMatchIn(exceptionMessage).isEmpty)
+    } finally {
+      Thread.currentThread.setContextClassLoader(originalClassLoader)
+    }
+  }
 }
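Outside the suite, the original bug can be reproduced along the lines of the SPARK-11195 report. A hypothetical standalone repro (the jar name and class are assumptions; the jar must define repro.MyException):

// Launch with the user jar on the classpath, e.g.:
//   spark-shell --master local --jars myexception.jar
// `sc` is the SparkContext predefined by spark-shell.
val rdd = sc.parallelize(1 to 1, 1).map { _ =>
  // Load the user-defined exception through the context classloader and throw it.
  val cls = Class.forName("repro.MyException", true,
    Thread.currentThread().getContextClassLoader)
  throw cls.newInstance().asInstanceOf[Exception]
}
rdd.collect()  // before the fix: the driver reported UnknownReason instead of repro.MyException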
