[SPARK-6209] Clean up connections in ExecutorClassLoader after failing to load classes (branch-1.2) #5174

Closed
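For context, the leak fixed here stems from HttpURLConnection keep-alive behavior: a connection becomes eligible for re-use only once its response stream (or, for error responses, its error stream) has been closed. Every failed class lookup therefore left a connection open, eventually exhausting the class server's Jetty thread / connection pool. Below is a condensed, self-contained sketch of the cleanup pattern the patch applies; the object and method names (ClassFetchSketch, openClassStream) are illustrative, not part of the patch.

import java.io.{IOException, InputStream}
import java.net.{HttpURLConnection, URL}

import scala.util.control.NonFatal

object ClassFetchSketch {
  // Open a stream to a class file over HTTP, cleaning up the connection on failure.
  def openClassStream(url: URL): InputStream = {
    val connection = url.openConnection().asInstanceOf[HttpURLConnection]
    connection.connect()
    try {
      if (connection.getResponseCode != 200) {
        // Closing the error stream lets the underlying socket be re-used
        try {
          connection.getErrorStream.close()
        } catch {
          case _: IOException => // best-effort cleanup; nothing more to do
        }
        throw new ClassNotFoundException(s"Class file not found at URL $url")
      } else {
        connection.getInputStream
      }
    } catch {
      // On unexpected failures, give up on the connection entirely
      case NonFatal(e) if !e.isInstanceOf[ClassNotFoundException] =>
        connection.disconnect()
        throw e
    }
  }
}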
5 changes: 5 additions & 0 deletions repl/pom.xml
@@ -96,6 +96,11 @@
       <artifactId>scalacheck_${scala.binary.version}</artifactId>
       <scope>test</scope>
     </dependency>
+    <dependency>
+      <groupId>org.mockito</groupId>
+      <artifactId>mockito-all</artifactId>
+      <scope>test</scope>
+    </dependency>
   </dependencies>
   <build>
     <outputDirectory>target/scala-${scala.binary.version}/classes</outputDirectory>
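The new test-scoped mockito-all dependency backs the MockitoSugar stubbing used in the suite further down, where SparkEnv is mocked so the class loader sees the test's SecurityManager. A minimal sketch of that stubbing pattern, assuming the usual ScalaTest MockitoSugar trait; Config and StubbingSketch are made-up names for illustration:

import org.scalatest.FunSuite
import org.scalatest.mock.MockitoSugar
import org.mockito.Mockito.when

class StubbingSketch extends FunSuite with MockitoSugar {
  trait Config { def timeoutMillis: Int }

  test("a stubbed collaborator returns the canned value") {
    val config = mock[Config]                  // generate a mock implementation
    when(config.timeoutMillis).thenReturn(500) // stub the method call
    assert(config.timeoutMillis === 500)
  }
}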
92 changes: 75 additions & 17 deletions repl/src/main/scala/org/apache/spark/repl/ExecutorClassLoader.scala
@@ -17,13 +17,14 @@

 package org.apache.spark.repl

-import java.io.{ByteArrayOutputStream, InputStream}
-import java.net.{URI, URL, URLEncoder}
-import java.util.concurrent.{Executors, ExecutorService}
+import java.io.{IOException, ByteArrayOutputStream, InputStream}
+import java.net.{HttpURLConnection, URI, URL, URLEncoder}

+import scala.util.control.NonFatal
+
 import org.apache.hadoop.fs.{FileSystem, Path}

-import org.apache.spark.{SparkConf, SparkEnv}
+import org.apache.spark.{Logging, SparkConf, SparkEnv}
 import org.apache.spark.deploy.SparkHadoopUtil
 import org.apache.spark.util.Utils
 import org.apache.spark.util.ParentClassLoader
@@ -37,12 +38,15 @@ import com.esotericsoftware.reflectasm.shaded.org.objectweb.asm.Opcodes._
  * Allows the user to specify if user class path should be first
  */
 class ExecutorClassLoader(conf: SparkConf, classUri: String, parent: ClassLoader,
-    userClassPathFirst: Boolean) extends ClassLoader {
+    userClassPathFirst: Boolean) extends ClassLoader with Logging {
   val uri = new URI(classUri)
   val directory = uri.getPath

   val parentLoader = new ParentClassLoader(parent)

+  // Allows HTTP connect and read timeouts to be controlled for testing / debugging purposes
+  private[repl] var httpUrlConnectionTimeoutMillis: Int = -1
+
   // Hadoop FileSystem object for our URI, if it isn't using HTTP
   var fileSystem: FileSystem = {
     if (uri.getScheme() == "http") {
@@ -71,27 +75,81 @@ class ExecutorClassLoader(conf: SparkConf, classUri: String, parent: ClassLoader
     }
   }

+  private def getClassFileInputStreamFromHttpServer(pathInDirectory: String): InputStream = {
+    val url = if (SparkEnv.get.securityManager.isAuthenticationEnabled()) {
+      val uri = new URI(classUri + "/" + urlEncode(pathInDirectory))
+      val newuri = Utils.constructURIForAuthentication(uri, SparkEnv.get.securityManager)
+      newuri.toURL
+    } else {
+      new URL(classUri + "/" + urlEncode(pathInDirectory))
+    }
+    val connection: HttpURLConnection = url.openConnection().asInstanceOf[HttpURLConnection]
+    // Set the connection timeouts (for testing purposes)
+    if (httpUrlConnectionTimeoutMillis != -1) {
+      connection.setConnectTimeout(httpUrlConnectionTimeoutMillis)
+      connection.setReadTimeout(httpUrlConnectionTimeoutMillis)
+    }
+    connection.connect()
+    try {
+      if (connection.getResponseCode != 200) {
+        // Close the error stream so that the connection is eligible for re-use
+        try {
+          connection.getErrorStream.close()
+        } catch {
+          case ioe: IOException =>
+            logError("Exception while closing error stream", ioe)
+        }
+        throw new ClassNotFoundException(s"Class file not found at URL $url")
+      } else {
+        connection.getInputStream
+      }
+    } catch {
+      case NonFatal(e) if !e.isInstanceOf[ClassNotFoundException] =>
+        connection.disconnect()
+        throw e
+    }
+  }
+
+  private def getClassFileInputStreamFromFileSystem(pathInDirectory: String): InputStream = {
+    val path = new Path(directory, pathInDirectory)
+    if (fileSystem.exists(path)) {
+      fileSystem.open(path)
+    } else {
+      throw new ClassNotFoundException(s"Class file not found at path $path")
+    }
+  }
+
   def findClassLocally(name: String): Option[Class[_]] = {
+    val pathInDirectory = name.replace('.', '/') + ".class"
+    var inputStream: InputStream = null
     try {
-      val pathInDirectory = name.replace('.', '/') + ".class"
-      val inputStream = {
+      inputStream = {
         if (fileSystem != null) {
-          fileSystem.open(new Path(directory, pathInDirectory))
+          getClassFileInputStreamFromFileSystem(pathInDirectory)
         } else {
-          if (SparkEnv.get.securityManager.isAuthenticationEnabled()) {
-            val uri = new URI(classUri + "/" + urlEncode(pathInDirectory))
-            val newuri = Utils.constructURIForAuthentication(uri, SparkEnv.get.securityManager)
-            newuri.toURL().openStream()
-          } else {
-            new URL(classUri + "/" + urlEncode(pathInDirectory)).openStream()
-          }
+          getClassFileInputStreamFromHttpServer(pathInDirectory)
         }
       }
       val bytes = readAndTransformClass(name, inputStream)
-      inputStream.close()
       Some(defineClass(name, bytes, 0, bytes.length))
     } catch {
-      case e: Exception => None
+      case e: ClassNotFoundException =>
+        // We did not find the class
+        logDebug(s"Did not load class $name from REPL class server at $uri", e)
+        None
+      case e: Exception =>
+        // Something bad happened while checking if the class exists
+        logError(s"Failed to check existence of class $name on REPL class server at $uri", e)
+        None
+    } finally {
+      if (inputStream != null) {
+        try {
+          inputStream.close()
+        } catch {
+          case e: Exception =>
+            logError("Exception while closing inputStream", e)
+        }
+      }
     }
   }
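Two details of the new findClassLocally are worth calling out: the NonFatal extractor in the HTTP path lets truly fatal JVM errors propagate rather than treating them as an ordinary fetch failure, and the finally block guarantees the stream is closed whether readAndTransformClass succeeds or throws. A minimal standalone sketch of that close-in-finally shape; withClassStream is a made-up helper name, and println stands in for the Logging calls:

import java.io.InputStream

object StreamCleanupSketch {
  // Open a stream, read it, and always close it, logging (not masking) close failures.
  def withClassStream[T](open: () => InputStream)(read: InputStream => T): Option[T] = {
    var inputStream: InputStream = null
    try {
      inputStream = open()
      Some(read(inputStream))
    } catch {
      case e: Exception =>
        println(s"failed to read class bytes: $e") // stand-in for logError
        None
    } finally {
      if (inputStream != null) {
        try {
          inputStream.close()
        } catch {
          case e: Exception => println(s"failed to close stream: $e")
        }
      }
    }
  }
}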

repl/src/test/scala/org/apache/spark/repl/ExecutorClassLoaderSuite.scala
@@ -20,20 +20,33 @@ package org.apache.spark.repl
 import java.io.File
 import java.net.{URL, URLClassLoader}

+import scala.concurrent.duration._
+import scala.language.implicitConversions
+import scala.language.postfixOps
+
 import org.scalatest.BeforeAndAfterAll
 import org.scalatest.FunSuite
+import org.scalatest.concurrent.Interruptor
+import org.scalatest.concurrent.Timeouts._
+import org.scalatest.mock.MockitoSugar
+import org.mockito.Mockito._

-import org.apache.spark.{SparkConf, TestUtils}
+import org.apache.spark._
 import org.apache.spark.util.Utils

-class ExecutorClassLoaderSuite extends FunSuite with BeforeAndAfterAll {
+class ExecutorClassLoaderSuite
+  extends FunSuite
+  with BeforeAndAfterAll
+  with MockitoSugar
+  with Logging {

   val childClassNames = List("ReplFakeClass1", "ReplFakeClass2")
   val parentClassNames = List("ReplFakeClass1", "ReplFakeClass2", "ReplFakeClass3")
   var tempDir1: File = _
   var tempDir2: File = _
   var url1: String = _
   var urls2: Array[URL] = _
+  var classServer: HttpServer = _

   override def beforeAll() {
     super.beforeAll()
@@ -47,8 +60,12 @@ class ExecutorClassLoaderSuite extends FunSuite with BeforeAndAfterAll {

   override def afterAll() {
     super.afterAll()
+    if (classServer != null) {
+      classServer.stop()
+    }
     Utils.deleteRecursively(tempDir1)
     Utils.deleteRecursively(tempDir2)
+    SparkEnv.set(null)
   }

   test("child first") {
@@ -83,4 +100,53 @@
     }
   }

+  test("failing to fetch classes from HTTP server should not leak resources (SPARK-6209)") {
+    // This is a regression test for SPARK-6209, a bug where each failed attempt to load a class
+    // from the driver's class server would leak a HTTP connection, causing the class server's
+    // thread / connection pool to be exhausted.
+    val conf = new SparkConf()
+    val securityManager = new SecurityManager(conf)
+    classServer = new HttpServer(conf, tempDir1, securityManager)
+    classServer.start()
+    // ExecutorClassLoader uses SparkEnv's SecurityManager, so we need to mock this
+    val mockEnv = mock[SparkEnv]
+    when(mockEnv.securityManager).thenReturn(securityManager)
+    SparkEnv.set(mockEnv)
+    // Create an ExecutorClassLoader that's configured to load classes from the HTTP server
+    val parentLoader = new URLClassLoader(Array.empty, null)
+    val classLoader = new ExecutorClassLoader(conf, classServer.uri, parentLoader, false)
+    classLoader.httpUrlConnectionTimeoutMillis = 500
+    // Check that this class loader can actually load classes that exist
+    val fakeClass = classLoader.loadClass("ReplFakeClass2").newInstance()
+    val fakeClassVersion = fakeClass.toString
+    assert(fakeClassVersion === "1")
+    // Try to perform a full GC now, since GC during the test might mask resource leaks
+    System.gc()
+    // When the original bug occurs, the test thread becomes blocked in a classloading call
+    // and does not respond to interrupts. Therefore, use a custom ScalaTest interruptor to
+    // shut down the HTTP server when the test times out
+    val interruptor: Interruptor = new Interruptor {
+      override def apply(thread: Thread): Unit = {
+        classServer.stop()
+        classServer = null
+        thread.interrupt()
+      }
+    }
+    def tryAndFailToLoadABunchOfClasses(): Unit = {
+      // The number of trials here should be much larger than Jetty's thread / connection limit
+      // in order to expose thread or connection leaks
+      for (i <- 1 to 1000) {
+        if (Thread.currentThread().isInterrupted) {
+          throw new InterruptedException()
+        }
+        // Incorporate the iteration number into the class name in order to avoid any response
+        // caching that might be added in the future
+        intercept[ClassNotFoundException] {
+          classLoader.loadClass(s"ReplFakeClassDoesNotExist$i").newInstance()
+        }
+      }
+    }
+    failAfter(10 seconds)(tryAndFailToLoadABunchOfClasses())(interruptor)
+  }
+
 }
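The failAfter / Interruptor combination above is the general ScalaTest recipe for timing out code that blocks without honoring Thread.interrupt(): the interruptor receives the test thread and can tear down whatever the thread is blocked on before (or instead of) interrupting it. A standalone sketch of the same pattern, using the same Timeouts API the suite imports; TimeoutSketch is a made-up class and a volatile flag stands in for the blocking resource:

import scala.concurrent.duration._
import scala.language.postfixOps

import org.scalatest.FunSuite
import org.scalatest.concurrent.Interruptor
import org.scalatest.concurrent.Timeouts._

class TimeoutSketch extends FunSuite {
  test("failAfter can unblock a loop that ignores interrupts") {
    @volatile var stop = false
    // Rather than relying on interruption alone, release the resource the
    // test thread is spinning on, then interrupt for good measure.
    val unblocker: Interruptor = new Interruptor {
      override def apply(thread: Thread): Unit = {
        stop = true
        thread.interrupt()
      }
    }
    // failAfter still reports a timeout even though the loop eventually exits.
    intercept[Exception] {
      failAfter(1 second)({
        while (!stop) { /* spin, ignoring interrupts */ }
      })(unblocker)
    }
  }
}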