[SPARK-17666] Ensure that RecordReaders are closed by data source file scans #15245

Status: Closed

Changes in LibSVMFileFormat:

@@ -25,6 +25,7 @@ import org.apache.hadoop.io.{NullWritable, Text}
 import org.apache.hadoop.mapreduce.{Job, RecordWriter, TaskAttemptContext}
 import org.apache.hadoop.mapreduce.lib.output.TextOutputFormat
 
+import org.apache.spark.TaskContext
 import org.apache.spark.ml.feature.LabeledPoint
 import org.apache.spark.ml.linalg.{Vector, Vectors, VectorUDT}
 import org.apache.spark.mllib.util.MLUtils
@@ -159,8 +160,10 @@ private[libsvm] class LibSVMFileFormat extends TextBasedFileFormat with DataSourceRegister {
       sparkSession.sparkContext.broadcast(new SerializableConfiguration(hadoopConf))
 
     (file: PartitionedFile) => {
-      val points =
-        new HadoopFileLinesReader(file, broadcastedHadoopConf.value.value)
+      val linesReader = new HadoopFileLinesReader(file, broadcastedHadoopConf.value.value)
+      Option(TaskContext.get()).foreach(_.addTaskCompletionListener(_ => linesReader.close()))
+
+      val points = linesReader
         .map(_.toString.trim)
         .filterNot(line => line.isEmpty || line.startsWith("#"))
         .map { line =>

Changes in HadoopFileLinesReader:

@@ -17,6 +17,7 @@
 
 package org.apache.spark.sql.execution.datasources
 
+import java.io.Closeable
 import java.net.URI
 
 import org.apache.hadoop.conf.Configuration
@@ -30,7 +31,8 @@ import org.apache.hadoop.mapreduce.task.TaskAttemptContextImpl
  * An adaptor from a [[PartitionedFile]] to an [[Iterator]] of [[Text]], which are all of the lines
  * in that file.
  */
-class HadoopFileLinesReader(file: PartitionedFile, conf: Configuration) extends Iterator[Text] {
+class HadoopFileLinesReader(
+    file: PartitionedFile, conf: Configuration) extends Iterator[Text] with Closeable {
   private val iterator = {

Review comment from the Contributor Author: "This is a RecordReaderIterator, whose memory footprint should become practically nothing after close() is called, hence my decision to not null things out here." (A usage sketch of the now-Closeable reader follows this file's diff.)

     val fileSplit = new FileSplit(
       new Path(new URI(file.filePath)),
@@ -48,4 +50,6 @@ class HadoopFileLinesReader(file: PartitionedFile, conf: Configuration) extends Iterator[Text] {
   override def hasNext: Boolean = iterator.hasNext
 
   override def next(): Text = iterator.next()
+
+  override def close(): Unit = iterator.close()
 }
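
Because HadoopFileLinesReader now mixes in Closeable, each of the per-format scans below registers it for cleanup with the running task. A minimal sketch of that calling pattern; the readLines helper is illustrative only and assumes a PartitionedFile and a Hadoop Configuration such as the ones available inside each format's buildReader:

```scala
import org.apache.hadoop.conf.Configuration
import org.apache.spark.TaskContext
import org.apache.spark.sql.execution.datasources.{HadoopFileLinesReader, PartitionedFile}

// Illustrative helper, not part of this patch: open a lines reader for one file split and
// ensure its underlying RecordReader is closed when the task completes, even if the
// returned iterator is never fully consumed.
def readLines(file: PartitionedFile, conf: Configuration): Iterator[String] = {
  val linesReader = new HadoopFileLinesReader(file, conf)
  Option(TaskContext.get()).foreach(_.addTaskCompletionListener(_ => linesReader.close()))
  linesReader.map(_.toString)
}
```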

Changes in RecordReaderIterator:

@@ -17,6 +17,8 @@
 
 package org.apache.spark.sql.execution.datasources
 
+import java.io.Closeable
+
 import org.apache.hadoop.mapreduce.RecordReader
 
 import org.apache.spark.sql.catalyst.InternalRow
@@ -27,7 +29,8 @@ import org.apache.spark.sql.catalyst.InternalRow
  * Note that this returns [[Object]]s instead of [[InternalRow]] because we rely on erasure to pass
  * column batches by pretending they are rows.
  */
-class RecordReaderIterator[T](rowReader: RecordReader[_, T]) extends Iterator[T] {
+class RecordReaderIterator[T](
+    private[this] var rowReader: RecordReader[_, T]) extends Iterator[T] with Closeable {

Review comment from the Contributor Author: "By nulling out the rowReader I think that this will prevent memory consumption from becoming too high in the list of task completion callbacks." (See the sketch at the end of this file's diff.)

   private[this] var havePair = false
   private[this] var finished = false
 
@@ -38,7 +41,7 @@ class RecordReaderIterator[T](rowReader: RecordReader[_, T]) extends Iterator[T] {
         // Close and release the reader here; close() will also be called when the task
         // completes, but for tasks that read from many files, it helps to release the
         // resources early.
-        rowReader.close()
+        close()
       }
       havePair = !finished
     }

Review comment from the Contributor Author: "This is CompletionIterator-style cleanup." (A sketch of that pattern follows below.)
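
For readers who have not seen the class referenced above: CompletionIterator is a Spark-internal utility that runs a cleanup action once the iterator it wraps is exhausted. A minimal sketch of that pattern, using illustrative names rather than Spark's actual API:

```scala
// Illustrative sketch of CompletionIterator-style cleanup: run onCompletion exactly once,
// as soon as the wrapped iterator reports that it has no more elements.
class CleanupIterator[A](underlying: Iterator[A], onCompletion: () => Unit) extends Iterator[A] {
  private[this] var completed = false

  override def hasNext: Boolean = {
    val more = underlying.hasNext
    if (!more && !completed) {
      completed = true
      onCompletion() // e.g. close a RecordReader early instead of waiting for the task to end
    }
    more
  }

  override def next(): A = underlying.next()
}
```

RecordReaderIterator folds the same idea directly into hasNext: as soon as the reader reports no more records, close() releases it instead of waiting for the task-completion callback.
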
@@ -52,4 +55,18 @@ class RecordReaderIterator[T](rowReader: RecordReader[_, T]) extends Iterator[T] {
     havePair = false
     rowReader.getCurrentValue
   }
+
+  override def close(): Unit = {
+    if (rowReader != null) {
+      try {
+        // Close the reader and release it. Note: it's very important that we don't close the
+        // reader more than once, since that exposes us to MAPREDUCE-5918 when running against
+        // older Hadoop 2.x releases. That bug can lead to non-deterministic corruption issues
+        // when reading compressed input.
+        rowReader.close()
+      } finally {
+        rowReader = null
+      }
+    }
+  }
 }

Review comment from the Contributor Author: "This comment is copied from NewHadoopRDD, which contains similar defensive programming."
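
To make the earlier comment about the task-completion callback list concrete: a single task that scans many files registers one listener per file, and each listener closure keeps its iterator reachable until the task finishes. Because close() nulls out rowReader, an already-closed iterator that a listener still references is an empty shell rather than a reader holding buffers. A hedged sketch of that shape; openRecordReader and processRow are placeholders, not Spark APIs:

```scala
import org.apache.hadoop.mapreduce.RecordReader
import org.apache.spark.TaskContext
import org.apache.spark.sql.execution.datasources.RecordReaderIterator

// Hypothetical per-file scan loop running inside one task. Every listener registered here
// stays in the task's callback list until the task completes; nulling the RecordReader in
// close() keeps what those retained closures hold on to small.
def scanFiles[T](paths: Seq[String],
                 openRecordReader: String => RecordReader[_, T],
                 processRow: T => Unit): Unit = {
  paths.foreach { path =>
    val iter = new RecordReaderIterator(openRecordReader(path))
    Option(TaskContext.get()).foreach(_.addTaskCompletionListener(_ => iter.close()))
    iter.foreach(processRow)
  }
}
```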

Changes in CSVFileFormat:

@@ -25,6 +25,7 @@ import org.apache.hadoop.io.{LongWritable, Text}
 import org.apache.hadoop.mapred.TextInputFormat
 import org.apache.hadoop.mapreduce._
 
+import org.apache.spark.TaskContext
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.SparkSession
 import org.apache.spark.sql.catalyst.InternalRow
@@ -112,7 +113,9 @@ class CSVFileFormat extends TextBasedFileFormat with DataSourceRegister {
     (file: PartitionedFile) => {
       val lineIterator = {
         val conf = broadcastedHadoopConf.value.value
-        new HadoopFileLinesReader(file, conf).map { line =>
+        val linesReader = new HadoopFileLinesReader(file, conf)
+        Option(TaskContext.get()).foreach(_.addTaskCompletionListener(_ => linesReader.close()))
+        linesReader.map { line =>
           new String(line.getBytes, 0, line.getLength, csvOptions.charset)
         }
       }

Changes in JsonFileFormat:

@@ -27,6 +27,7 @@ import org.apache.hadoop.mapreduce.{Job, RecordWriter, TaskAttemptContext}
 import org.apache.hadoop.mapreduce.lib.input.FileInputFormat
 import org.apache.hadoop.mapreduce.lib.output.TextOutputFormat
 
+import org.apache.spark.TaskContext
 import org.apache.spark.internal.Logging
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.{AnalysisException, Row, SparkSession}
@@ -104,7 +105,9 @@ class JsonFileFormat extends TextBasedFileFormat with DataSourceRegister {
       .getOrElse(sparkSession.sessionState.conf.columnNameOfCorruptRecord)
 
     (file: PartitionedFile) => {
-      val lines = new HadoopFileLinesReader(file, broadcastedHadoopConf.value.value).map(_.toString)
+      val linesReader = new HadoopFileLinesReader(file, broadcastedHadoopConf.value.value)
+      Option(TaskContext.get()).foreach(_.addTaskCompletionListener(_ => linesReader.close()))
+      val lines = linesReader.map(_.toString)
       val parser = new JacksonParser(requiredSchema, columnNameOfCorruptRecord, parsedOptions)
       lines.flatMap(parser.parse)
     }

Changes in ParquetFileFormat:

@@ -37,7 +37,7 @@ import org.apache.parquet.hadoop.util.ContextUtil
 import org.apache.parquet.schema.MessageType
 import org.slf4j.bridge.SLF4JBridgeHandler
 
-import org.apache.spark.SparkException
+import org.apache.spark.{SparkException, TaskContext}
 import org.apache.spark.internal.Logging
 import org.apache.spark.sql._
 import org.apache.spark.sql.catalyst.InternalRow
@@ -388,6 +388,7 @@ class ParquetFileFormat
       }
 
       val iter = new RecordReaderIterator(parquetReader)
+      Option(TaskContext.get()).foreach(_.addTaskCompletionListener(_ => iter.close()))
 
       // UnsafeRowParquetRecordReader appends the columns internally to avoid another copy.
       if (parquetReader.isInstanceOf[VectorizedParquetRecordReader] &&

Changes in TextFileFormat:

@@ -23,6 +23,7 @@ import org.apache.hadoop.io.{NullWritable, Text}
 import org.apache.hadoop.mapreduce.{Job, RecordWriter, TaskAttemptContext}
 import org.apache.hadoop.mapreduce.lib.output.TextOutputFormat
 
+import org.apache.spark.TaskContext
 import org.apache.spark.sql.{AnalysisException, Row, SparkSession}
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions.UnsafeRow
@@ -100,6 +101,7 @@ class TextFileFormat extends TextBasedFileFormat with DataSourceRegister {
 
     (file: PartitionedFile) => {
       val reader = new HadoopFileLinesReader(file, broadcastedHadoopConf.value.value)
+      Option(TaskContext.get()).foreach(_.addTaskCompletionListener(_ => reader.close()))
 
       if (requiredSchema.isEmpty) {
         val emptyUnsafeRow = new UnsafeRow(0)

Changes in OrcFileFormat:

@@ -31,6 +31,7 @@ import org.apache.hadoop.mapred.{JobConf, OutputFormat => MapRedOutputFormat, Re
 import org.apache.hadoop.mapreduce._
 import org.apache.hadoop.mapreduce.lib.input.{FileInputFormat, FileSplit}
 
+import org.apache.spark.TaskContext
 import org.apache.spark.sql.{Row, SparkSession}
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions._
@@ -146,12 +147,15 @@ class OrcFileFormat extends FileFormat with DataSourceRegister with Serializable
           new SparkOrcNewRecordReader(orcReader, conf, fileSplit.getStart, fileSplit.getLength)
         }
 
+        val recordsIterator = new RecordReaderIterator[OrcStruct](orcRecordReader)
+        Option(TaskContext.get()).foreach(_.addTaskCompletionListener(_ => recordsIterator.close()))
+
         // Unwraps `OrcStruct`s to `UnsafeRow`s
         OrcRelation.unwrapOrcStructs(
           conf,
           requiredSchema,
           Some(orcRecordReader.getObjectInspector.asInstanceOf[StructObjectInspector]),
-          new RecordReaderIterator[OrcStruct](orcRecordReader))
+          recordsIterator)
       }
     }
   }