
[SPARK-6209] Clean up connections in ExecutorClassLoader after failing to load classes (master branch PR) #4944


Status: Closed
5 changes: 5 additions & 0 deletions repl/pom.xml
@@ -84,6 +84,11 @@
      <artifactId>scalacheck_${scala.binary.version}</artifactId>
      <scope>test</scope>
    </dependency>
+    <dependency>
+      <groupId>org.mockito</groupId>
+      <artifactId>mockito-all</artifactId>
+      <scope>test</scope>
+    </dependency>

    <!-- Explicit listing of transitive deps that are shaded. Otherwise, odd compiler crashes. -->
    <dependency>
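A side note on the added dependency: it declares no explicit <version>, presumably because the mockito-all version is managed in the Spark parent pom's dependencyManagement section.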
85 changes: 67 additions & 18 deletions repl/src/main/scala/org/apache/spark/repl/ExecutorClassLoader.scala
@@ -17,9 +17,10 @@

package org.apache.spark.repl

-import java.io.{ByteArrayOutputStream, InputStream, FileNotFoundException}
-import java.net.{URI, URL, URLEncoder}
-import java.util.concurrent.{Executors, ExecutorService}
+import java.io.{IOException, ByteArrayOutputStream, InputStream}
+import java.net.{HttpURLConnection, URI, URL, URLEncoder}
+
+import scala.util.control.NonFatal

import org.apache.hadoop.fs.{FileSystem, Path}

@@ -43,6 +44,9 @@ class ExecutorClassLoader(conf: SparkConf, classUri: String, parent: ClassLoader

  val parentLoader = new ParentClassLoader(parent)

+  // Allows HTTP connect and read timeouts to be controlled for testing / debugging purposes
+  private[repl] var httpUrlConnectionTimeoutMillis: Int = -1

  // Hadoop FileSystem object for our URI, if it isn't using HTTP
  var fileSystem: FileSystem = {
    if (Set("http", "https", "ftp").contains(uri.getScheme)) {
@@ -71,37 +75,82 @@
    }
  }

+  private def getClassFileInputStreamFromHttpServer(pathInDirectory: String): InputStream = {
+    val url = if (SparkEnv.get.securityManager.isAuthenticationEnabled()) {
+      val uri = new URI(classUri + "/" + urlEncode(pathInDirectory))
+      val newuri = Utils.constructURIForAuthentication(uri, SparkEnv.get.securityManager)
+      newuri.toURL
+    } else {
+      new URL(classUri + "/" + urlEncode(pathInDirectory))
+    }
+    val connection: HttpURLConnection = Utils.setupSecureURLConnection(url.openConnection(),
+      SparkEnv.get.securityManager).asInstanceOf[HttpURLConnection]
+    // Set the connection timeouts (for testing purposes)
+    if (httpUrlConnectionTimeoutMillis != -1) {
+      connection.setConnectTimeout(httpUrlConnectionTimeoutMillis)
+      connection.setReadTimeout(httpUrlConnectionTimeoutMillis)
+    }
+    connection.connect()
Review thread on connection.connect():

Contributor:

I'm not too familiar with this API, but can this connection object ever leak even if the returned input stream is closed? Or can an exception occur before this method returns, leaking the connection that way? I'm guessing these aren't possible, but I just wanted to educate myself as to how these scenarios won't happen =)

Contributor (Author):

Per the HttpURLConnection Javadoc:

> Each HttpURLConnection instance is used to make a single request but the underlying network connection to the HTTP server may be transparently shared by other instances. Calling the close() methods on the InputStream or OutputStream of an HttpURLConnection after a request may free network resources associated with this instance but has no effect on any shared persistent connection. Calling the disconnect() method may close the underlying socket if a persistent connection is otherwise idle at that time.

This can cause problems in practice: https://scotte.github.io/2015/01/httpurlconnection-socket-leak/

Actually, that blog post reminds me that I should probably call getErrorStream and read that stream in the next error-handling block prior to calling disconnect.

Contributor (Author):

On closer inspection, it looks like you need to consume the errorStream if you want the underlying connection to be eligible for re-use, but we may not necessarily care about that here: http://docs.oracle.com/javase/6/docs/technotes/guides/net/http-keepalive.html
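To make the keep-alive point concrete, here is a minimal sketch of draining an HttpURLConnection error stream so the underlying socket stays eligible for re-use. This is not code from the PR; the helper name and buffer size are arbitrary:

```scala
import java.io.IOException
import java.net.HttpURLConnection

// Hypothetical helper: consume and close the error stream so the
// underlying keep-alive connection can be returned to the pool.
def drainErrorStream(connection: HttpURLConnection): Unit = {
  val errorStream = connection.getErrorStream // null if there is no error body
  if (errorStream != null) {
    try {
      val buffer = new Array[Byte](4096)
      while (errorStream.read(buffer) != -1) {} // read to end-of-stream
    } catch {
      case _: IOException => // draining is best-effort; ignore read failures
    } finally {
      errorStream.close()
    }
  }
}
```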

Contributor:

What if an exception is thrown before we return the input stream? In that case the input stream is not closed, but perhaps the connection is still open?

I'm basically asking: should we have a catch block for Throwable in this method, after we instantiate the connection, that closes the connection if we run into an exception before returning the input stream?

Contributor (Author):

I think the connect() call is unlikely to fail in ways that leak resources, but I suppose it wouldn't hurt to add another try-catch block just to be safe. I can do this tomorrow.

+    try {
+      if (connection.getResponseCode != 200) {
+        // Close the error stream so that the connection is eligible for re-use
+        try {
+          connection.getErrorStream.close()
+        } catch {
+          case ioe: IOException =>
+            logError("Exception while closing error stream", ioe)
+        }
+        throw new ClassNotFoundException(s"Class file not found at URL $url")
+      } else {
+        connection.getInputStream
+      }
+    } catch {
+      case NonFatal(e) if !e.isInstanceOf[ClassNotFoundException] =>
+        connection.disconnect()
+        throw e
+    }
+  }
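As background on the NonFatal guard above: scala.util.control.NonFatal matches any Throwable except fatal JVM errors (VirtualMachineError and friends), InterruptedException, and Scala's control-flow throwables. A generic sketch of the resulting cleanup pattern, assuming a hypothetical withCleanupOnFailure helper rather than anything in this PR:

```scala
import scala.util.control.NonFatal

// Run `body`, invoking `cleanup` only on non-fatal failures;
// fatal errors propagate without triggering cleanup.
def withCleanupOnFailure[T](cleanup: () => Unit)(body: => T): T = {
  try {
    body
  } catch {
    case NonFatal(e) =>
      cleanup() // e.g. connection.disconnect()
      throw e   // re-throw after releasing resources
  }
}
```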

+  private def getClassFileInputStreamFromFileSystem(pathInDirectory: String): InputStream = {
+    val path = new Path(directory, pathInDirectory)
+    if (fileSystem.exists(path)) {
+      fileSystem.open(path)
+    } else {
+      throw new ClassNotFoundException(s"Class file not found at path $path")
+    }
+  }

  def findClassLocally(name: String): Option[Class[_]] = {
+    val pathInDirectory = name.replace('.', '/') + ".class"
+    var inputStream: InputStream = null
    try {
-      val pathInDirectory = name.replace('.', '/') + ".class"
-      val inputStream = {
+      inputStream = {
        if (fileSystem != null) {
-          fileSystem.open(new Path(directory, pathInDirectory))
+          getClassFileInputStreamFromFileSystem(pathInDirectory)
        } else {
-          val url = if (SparkEnv.get.securityManager.isAuthenticationEnabled()) {
-            val uri = new URI(classUri + "/" + urlEncode(pathInDirectory))
-            val newuri = Utils.constructURIForAuthentication(uri, SparkEnv.get.securityManager)
-            newuri.toURL
-          } else {
-            new URL(classUri + "/" + urlEncode(pathInDirectory))
-          }
-
-          Utils.setupSecureURLConnection(url.openConnection(), SparkEnv.get.securityManager)
-            .getInputStream
+          getClassFileInputStreamFromHttpServer(pathInDirectory)
        }
      }
      val bytes = readAndTransformClass(name, inputStream)
-      inputStream.close()
      Some(defineClass(name, bytes, 0, bytes.length))
    } catch {
-      case e: FileNotFoundException =>
+      case e: ClassNotFoundException =>
        // We did not find the class
        logDebug(s"Did not load class $name from REPL class server at $uri", e)
        None
      case e: Exception =>
        // Something bad happened while checking if the class exists
        logError(s"Failed to check existence of class $name on REPL class server at $uri", e)
        None
+    } finally {
+      if (inputStream != null) {
+        try {
+          inputStream.close()
+        } catch {
+          case e: Exception =>
+            logError("Exception while closing inputStream", e)
+        }
+      }
    }
  }

repl/src/test/scala/org/apache/spark/repl/ExecutorClassLoaderSuite.scala
@@ -20,20 +20,33 @@ package org.apache.spark.repl
import java.io.File
import java.net.{URL, URLClassLoader}

+import scala.concurrent.duration._
+import scala.language.implicitConversions
+import scala.language.postfixOps

import org.scalatest.BeforeAndAfterAll
import org.scalatest.FunSuite
+import org.scalatest.concurrent.Interruptor
+import org.scalatest.concurrent.Timeouts._
+import org.scalatest.mock.MockitoSugar
+import org.mockito.Mockito._

-import org.apache.spark.{SparkConf, TestUtils}
+import org.apache.spark._
import org.apache.spark.util.Utils

-class ExecutorClassLoaderSuite extends FunSuite with BeforeAndAfterAll {
+class ExecutorClassLoaderSuite
+  extends FunSuite
+  with BeforeAndAfterAll
+  with MockitoSugar
+  with Logging {

  val childClassNames = List("ReplFakeClass1", "ReplFakeClass2")
  val parentClassNames = List("ReplFakeClass1", "ReplFakeClass2", "ReplFakeClass3")
  var tempDir1: File = _
  var tempDir2: File = _
  var url1: String = _
  var urls2: Array[URL] = _
+  var classServer: HttpServer = _

  override def beforeAll() {
    super.beforeAll()
@@ -47,8 +60,12 @@ class ExecutorClassLoaderSuite extends FunSuite with BeforeAndAfterAll {

  override def afterAll() {
    super.afterAll()
+    if (classServer != null) {
+      classServer.stop()
+    }
    Utils.deleteRecursively(tempDir1)
    Utils.deleteRecursively(tempDir2)
+    SparkEnv.set(null)
  }

test("child first") {
@@ -83,4 +100,53 @@
    }
  }

test("failing to fetch classes from HTTP server should not leak resources (SPARK-6209)") {
// This is a regression test for SPARK-6209, a bug where each failed attempt to load a class
// from the driver's class server would leak a HTTP connection, causing the class server's
// thread / connection pool to be exhausted.
val conf = new SparkConf()
val securityManager = new SecurityManager(conf)
classServer = new HttpServer(conf, tempDir1, securityManager)
classServer.start()
// ExecutorClassLoader uses SparkEnv's SecurityManager, so we need to mock this
val mockEnv = mock[SparkEnv]
when(mockEnv.securityManager).thenReturn(securityManager)
SparkEnv.set(mockEnv)
// Create an ExecutorClassLoader that's configured to load classes from the HTTP server
val parentLoader = new URLClassLoader(Array.empty, null)
val classLoader = new ExecutorClassLoader(conf, classServer.uri, parentLoader, false)
classLoader.httpUrlConnectionTimeoutMillis = 500
// Check that this class loader can actually load classes that exist
val fakeClass = classLoader.loadClass("ReplFakeClass2").newInstance()
val fakeClassVersion = fakeClass.toString
assert(fakeClassVersion === "1")
// Try to perform a full GC now, since GC during the test might mask resource leaks
System.gc()
// When the original bug occurs, the test thread becomes blocked in a classloading call
// and does not respond to interrupts. Therefore, use a custom ScalaTest interruptor to
// shut down the HTTP server when the test times out
val interruptor: Interruptor = new Interruptor {
override def apply(thread: Thread): Unit = {
classServer.stop()
classServer = null
thread.interrupt()
}
}
def tryAndFailToLoadABunchOfClasses(): Unit = {
// The number of trials here should be much larger than Jetty's thread / connection limit
// in order to expose thread or connection leaks
for (i <- 1 to 1000) {
if (Thread.currentThread().isInterrupted) {
throw new InterruptedException()
}
// Incorporate the iteration number into the class name in order to avoid any response
// caching that might be added in the future
intercept[ClassNotFoundException] {
classLoader.loadClass(s"ReplFakeClassDoesNotExist$i").newInstance()
}
}
}
failAfter(10 seconds)(tryAndFailToLoadABunchOfClasses())(interruptor)
}

}
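As a usage note, this suite can be run on its own with something along the lines of build/sbt "repl/test-only org.apache.spark.repl.ExecutorClassLoaderSuite"; the sbt wrapper path and test-only syntax are assumptions that vary with the branch and sbt version.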