[SPARK-17740] Spark tests should mock / interpose HDFS to ensure that streams are closed #15306

Closed · wants to merge 5 commits
114 changes: 114 additions & 0 deletions core/src/test/scala/org/apache/spark/DebugFilesystem.scala
@@ -0,0 +1,114 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark

import java.io.{FileDescriptor, InputStream}
import java.lang
import java.nio.ByteBuffer
import java.util.concurrent.ConcurrentHashMap

import scala.collection.JavaConverters._
import scala.collection.mutable

import org.apache.hadoop.fs._

import org.apache.spark.internal.Logging

object DebugFilesystem extends Logging {
  // Stores the set of active streams and their creation sites.
  private val openStreams = new ConcurrentHashMap[FSDataInputStream, Throwable]()

  def clearOpenStreams(): Unit = {
    openStreams.clear()
  }

  def assertNoOpenStreams(): Unit = {
    val numOpen = openStreams.size()
    if (numOpen > 0) {
      for (exc <- openStreams.values().asScala) {
        logWarning("Leaked filesystem connection created at:")
        exc.printStackTrace()
      }
      throw new RuntimeException(s"There are $numOpen possibly leaked file streams.")
    }
  }
}

/**
 * DebugFilesystem wraps file open calls to track all open connections. This can be used in tests
 * to check that connections are not leaked.
 */
// TODO(ekl) we should consider always interposing this to expose num open conns as a metric
class DebugFilesystem extends LocalFileSystem {
  import DebugFilesystem._

  override def open(f: Path, bufferSize: Int): FSDataInputStream = {
    val wrapped: FSDataInputStream = super.open(f, bufferSize)
    openStreams.put(wrapped, new Throwable())

    new FSDataInputStream(wrapped.getWrappedStream) {
      override def setDropBehind(dropBehind: lang.Boolean): Unit = wrapped.setDropBehind(dropBehind)

      override def getWrappedStream: InputStream = wrapped.getWrappedStream

      override def getFileDescriptor: FileDescriptor = wrapped.getFileDescriptor

      override def getPos: Long = wrapped.getPos

      override def seekToNewSource(targetPos: Long): Boolean = wrapped.seekToNewSource(targetPos)

      override def seek(desired: Long): Unit = wrapped.seek(desired)

      override def setReadahead(readahead: lang.Long): Unit = wrapped.setReadahead(readahead)

      override def read(position: Long, buffer: Array[Byte], offset: Int, length: Int): Int =
        wrapped.read(position, buffer, offset, length)

      override def read(buf: ByteBuffer): Int = wrapped.read(buf)

      override def readFully(position: Long, buffer: Array[Byte], offset: Int, length: Int): Unit =
        wrapped.readFully(position, buffer, offset, length)

      override def readFully(position: Long, buffer: Array[Byte]): Unit =
        wrapped.readFully(position, buffer)

      override def available(): Int = wrapped.available()

      override def mark(readlimit: Int): Unit = wrapped.mark(readlimit)

      override def skip(n: Long): Long = wrapped.skip(n)

      override def markSupported(): Boolean = wrapped.markSupported()

      override def close(): Unit = {
        wrapped.close()
        openStreams.remove(wrapped)
      }

      override def read(): Int = wrapped.read()

      override def reset(): Unit = wrapped.reset()

      override def toString: String = wrapped.toString

      override def equals(obj: scala.Any): Boolean = wrapped.equals(obj)

      override def hashCode(): Int = wrapped.hashCode()
    }
  }
}
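
As a usage sketch (not part of this diff), this is roughly how a test could interpose DebugFilesystem on the local file system and use the two assertions above; the object name and file path are made up for illustration:

import java.net.URI

import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.{FileSystem, Path}

import org.apache.spark.DebugFilesystem

// Sketch only: assumes /tmp/example.txt already exists.
object DebugFilesystemUsageSketch {
  def main(args: Array[String]): Unit = {
    val conf = new Configuration()
    // Route file:// URIs through DebugFilesystem and bypass the FileSystem cache
    // so the wrapper is actually instantiated.
    conf.set("fs.file.impl", classOf[DebugFilesystem].getName)
    conf.setBoolean("fs.file.impl.disable.cache", true)

    DebugFilesystem.clearOpenStreams()

    val fs = FileSystem.get(new URI("file:///"), conf)
    val in = fs.open(new Path("/tmp/example.txt")) // hypothetical existing file
    try {
      in.read()
    } finally {
      in.close() // without this close, assertNoOpenStreams() below would throw
    }

    DebugFilesystem.assertNoOpenStreams()
  }
}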
17 changes: 14 additions & 3 deletions core/src/test/scala/org/apache/spark/SharedSparkContext.scala
@@ -17,11 +17,11 @@

package org.apache.spark

import org.scalatest.BeforeAndAfterAll
import org.scalatest.{BeforeAndAfterAll, BeforeAndAfterEach}
import org.scalatest.Suite

/** Shares a local `SparkContext` between all tests in a suite and closes it at the end */
trait SharedSparkContext extends BeforeAndAfterAll { self: Suite =>
trait SharedSparkContext extends BeforeAndAfterAll with BeforeAndAfterEach { self: Suite =>

@transient private var _sc: SparkContext = _

@@ -31,7 +31,8 @@ trait SharedSparkContext extends BeforeAndAfterAll { self: Suite =>

override def beforeAll() {
super.beforeAll()
_sc = new SparkContext("local[4]", "test", conf)
_sc = new SparkContext(
"local[4]", "test", conf.set("spark.hadoop.fs.file.impl", classOf[DebugFilesystem].getName))
}

override def afterAll() {
@@ -42,4 +43,14 @@ trait SharedSparkContext extends BeforeAndAfterAll { self: Suite =>
super.afterAll()
}
}

protected override def beforeEach(): Unit = {
super.beforeEach()
DebugFilesystem.clearOpenStreams()
}

protected override def afterEach(): Unit = {
super.afterEach()
DebugFilesystem.assertNoOpenStreams()
}
}
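
For illustration (not part of this diff), a hypothetical suite that mixes in SharedSparkContext now gets the leak check for free: if a test leaves a stream opened through DebugFilesystem unclosed, afterEach() throws "There are N possibly leaked file streams." and the test fails.

import org.apache.spark.{SharedSparkContext, SparkFunSuite}

// Hypothetical example suite, assuming the SharedSparkContext changes above.
class ExampleLeakCheckSuite extends SparkFunSuite with SharedSparkContext {
  test("job completes without leaking file streams") {
    // sc is the shared context, configured with spark.hadoop.fs.file.impl = DebugFilesystem.
    val total = sc.parallelize(1 to 10).map(_ * 2).sum()
    assert(total == 110)
    // If this test had opened a file stream through sc and not closed it,
    // afterEach() would throw and fail the test.
  }
}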
@@ -104,6 +104,7 @@ class ParquetEncodingSuite extends ParquetCompatibilityTest with SharedSQLContext
assert(column.getUTF8String(3 * i + 1).toString == i.toString)
assert(column.getUTF8String(3 * i + 2).toString == i.toString)
}
reader.close()
}
}
}
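
The reader.close() added above is the whole fix for this suite. If the same pattern recurs, a small close-after-use helper could keep it exception-safe; a generic sketch, not something this PR adds:

import java.io.Closeable

// Generic sketch (not part of this PR): run `body`, then close the resource even
// if an assertion throws, so DebugFilesystem.assertNoOpenStreams() stays green.
object CloseAfterUse {
  def withCloseable[C <: Closeable, T](resource: C)(body: C => T): T = {
    try {
      body(resource)
    } finally {
      resource.close()
    }
  }
}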
@@ -203,13 +203,14 @@ class HDFSMetadataLogSuite extends SparkFunSuite with SharedSQLContext {
}

// Open and delete
fm.open(path)
val f1 = fm.open(path)
fm.delete(path)
assert(!fm.exists(path))
intercept[IOException] {
fm.open(path)
}
fm.delete(path) // should not throw exception
f1.close()

// Rename
val path1 = new Path(s"$dir/file1")
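
The f1.close() added above matters because deleting a file does not close a stream that was already opened on it: the stream stays tracked by DebugFilesystem until close() is called. A standalone illustration of that behavior (not part of this diff; the temp path is hypothetical, and read-after-delete is local/POSIX behavior):

import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.{FileSystem, Path}

// Standalone illustration using the plain local file system.
object OpenThenDeleteSketch {
  def main(args: Array[String]): Unit = {
    val fs = FileSystem.getLocal(new Configuration())
    val path = new Path("/tmp/open-then-delete.txt") // hypothetical path

    val out = fs.create(path)
    out.writeBytes("hello")
    out.close()

    val in = fs.open(path)
    fs.delete(path, false) // delete while the stream is still open
    in.read()              // still readable on local/POSIX file systems
    in.close()             // the stream must still be closed explicitly
  }
}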
@@ -17,14 +17,16 @@

package org.apache.spark.sql.test

import org.apache.spark.SparkConf
import org.scalatest.BeforeAndAfterEach

import org.apache.spark.{DebugFilesystem, SparkConf}
import org.apache.spark.sql.{SparkSession, SQLContext}


/**
* Helper trait for SQL test suites where all tests share a single [[TestSparkSession]].
*/
trait SharedSQLContext extends SQLTestUtils {
trait SharedSQLContext extends SQLTestUtils with BeforeAndAfterEach {

protected val sparkConf = new SparkConf()

@@ -52,7 +54,8 @@ trait SharedSQLContext extends SQLTestUtils {
protected override def beforeAll(): Unit = {
SparkSession.sqlListener.set(null)
if (_spark == null) {
_spark = new TestSparkSession(sparkConf)
_spark = new TestSparkSession(
sparkConf.set("spark.hadoop.fs.file.impl", classOf[DebugFilesystem].getName))
}
// Ensure we have initialized the context before calling parent code
super.beforeAll()
@@ -71,4 +74,14 @@
super.afterAll()
}
}

protected override def beforeEach(): Unit = {
super.beforeEach()
DebugFilesystem.clearOpenStreams()
}

protected override def afterEach(): Unit = {
super.afterEach()
DebugFilesystem.assertNoOpenStreams()
}
}
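
Similarly (illustration only, not part of this diff), every SQL suite built on SharedSQLContext now runs the same per-test check; the suite below is hypothetical:

import org.apache.spark.sql.QueryTest
import org.apache.spark.sql.test.SharedSQLContext

// Hypothetical suite: a file-backed stream left open by a test (e.g. a Parquet
// reader that is never closed) would now fail that test in afterEach().
class ExampleSQLLeakSuite extends QueryTest with SharedSQLContext {
  test("simple query leaves no open file streams") {
    val df = spark.range(0, 10).toDF("id")
    assert(df.count() == 10)
  }
}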