
Commit e265c60

Authored by JacobZheng0927, committed by Mridul Muralidharan
[SPARK-47910][CORE] close stream when DiskBlockObjectWriter closeResources to avoid memory leak
### What changes were proposed in this pull request?
Close the stream when `DiskBlockObjectWriter.closeResources` is called, to avoid a memory leak.

### Why are the changes needed?
[SPARK-34647](https://issues.apache.org/jira/browse/SPARK-34647) replaced `ZstdInputStream` with `ZstdInputStreamNoFinalizer`. This means that every usage of `CompressionCodec.compressedOutputStream` must close the stream manually, since this is no longer handled by the finalizer mechanism. When zstd is used for shuffle write compression and the write is interrupted for some reason (e.g. `spark.sql.execution.interruptOnCancel` is enabled and the job is cancelled), the memory used by `ZstdInputStreamNoFinalizer` may never be freed, causing a memory leak.

### Does this PR introduce _any_ user-facing change?
No

### How was this patch tested?

#### Spark Shell Configuration
```
$> export SPARK_SUBMIT_OPTS="-XX:+AlwaysPreTouch -Xms1g"
$> $SPARK_HOME/bin/spark-shell --conf spark.io.compression.codec=zstd
```

#### Test Script
```scala
import java.util.concurrent.TimeUnit
import scala.concurrent.Future
import scala.concurrent.ExecutionContext.Implicits.global
import scala.util.Random

sc.setJobGroup("jobA", "this is a job to be cancelled", interruptOnCancel = true)
(1 to 50).foreach { batch => {
  val jobA = Future {
    val df1 = spark.range(2000000).map { _ =>
      (Random.nextString(20), Random.nextInt(1000), Random.nextInt(1000), Random.nextInt(10))
    }.toDF("a", "b", "c", "d")
    val df2 = spark.range(2000000).map { _ =>
      (Random.nextString(20), Random.nextInt(1000), Random.nextInt(1000), Random.nextInt(10))
    }.toDF("a", "b", "c", "d")
    df1.join(df2, "b").show()
  }
  Thread.sleep(5000)
  sc.cancelJobGroup("jobA")
}}
```

#### Memory Monitor
```
$> while true; do echo \"$(date +%Y-%m-%d' '%H:%M:%S)\",$(pmap -x <PID> | grep "total kB" | awk '{print $4}'); sleep 10; done;
```

#### Results

##### Before
```
"2024-05-13 16:54:23",1332384
"2024-05-13 16:54:33",1417112
"2024-05-13 16:54:43",2211684
"2024-05-13 16:54:53",3060820
"2024-05-13 16:55:03",3850444
"2024-05-13 16:55:14",4631744
"2024-05-13 16:55:24",5317200
"2024-05-13 16:55:34",6019464
"2024-05-13 16:55:44",6489180
"2024-05-13 16:55:54",7255548
"2024-05-13 16:56:05",7718728
"2024-05-13 16:56:15",8388392
"2024-05-13 16:56:25",8927636
"2024-05-13 16:56:36",9473412
"2024-05-13 16:56:46",10000380
"2024-05-13 16:56:56",10344024
"2024-05-13 16:57:07",10734204
"2024-05-13 16:57:17",11211900
"2024-05-13 16:57:27",11665524
"2024-05-13 16:57:38",12268976
"2024-05-13 16:57:48",12896264
"2024-05-13 16:57:58",13572244
"2024-05-13 16:58:09",14252416
"2024-05-13 16:58:19",14915560
"2024-05-13 16:58:30",15484196
"2024-05-13 16:58:40",16170324
```

##### After
```
"2024-05-13 16:35:44",1355428
"2024-05-13 16:35:54",1391028
"2024-05-13 16:36:04",1673720
"2024-05-13 16:36:14",2103716
"2024-05-13 16:36:24",2129876
"2024-05-13 16:36:35",2166412
"2024-05-13 16:36:45",2177672
"2024-05-13 16:36:55",2188340
"2024-05-13 16:37:05",2190688
"2024-05-13 16:37:15",2195168
"2024-05-13 16:37:26",2199296
"2024-05-13 16:37:36",2228052
"2024-05-13 16:37:46",2238104
"2024-05-13 16:37:56",2260624
"2024-05-13 16:38:06",2307184
"2024-05-13 16:38:16",2331140
"2024-05-13 16:38:27",2323388
"2024-05-13 16:38:37",2357552
"2024-05-13 16:38:47",2352948
"2024-05-13 16:38:57",2364744
"2024-05-13 16:39:07",2368528
"2024-05-13 16:39:18",2385492
"2024-05-13 16:39:28",2389184
"2024-05-13 16:39:38",2388060
"2024-05-13 16:39:48",2388336
"2024-05-13 16:39:58",2386916
```

### Was this patch authored or co-authored using generative AI tooling?
No

Closes #46131 from JacobZheng0927/zstdMemoryLeak.

Authored-by: JacobZheng0927 <zsh517559523@163.com>
Signed-off-by: Mridul Muralidharan <mridul<at>gmail.com>
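As additional context for the leak described above, here is a minimal standalone sketch, not part of this patch: it assumes zstd-jni's `com.github.luben.zstd.ZstdOutputStreamNoFinalizer` is on the classpath, and the `writeCompressed` helper is hypothetical. It illustrates why the no-finalizer zstd streams must be closed explicitly: their native buffers are released only by `close()`, so a write path that never reaches `close()` keeps that off-heap memory alive.

```scala
import java.io.{ByteArrayOutputStream, OutputStream}

import com.github.luben.zstd.ZstdOutputStreamNoFinalizer

// Hypothetical helper: without a finalizer, the zstd stream's native buffers are
// released only by close(), so close() must run even when the write is interrupted.
def writeCompressed(target: OutputStream, payload: Array[Byte]): Unit = {
  val zos = new ZstdOutputStreamNoFinalizer(target)
  try {
    zos.write(payload)
  } finally {
    zos.close() // skipping this on an interrupted write is the kind of leak this commit addresses
  }
}

writeCompressed(new ByteArrayOutputStream(), Array.fill(1024)(1.toByte))
```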
1 parent: f0b7cfa · commit: e265c60

2 files changed (+114, -17 lines)

core/src/main/scala/org/apache/spark/storage/DiskBlockObjectWriter.scala

Lines changed: 36 additions & 13 deletions
```diff
@@ -126,6 +126,12 @@ private[spark] class DiskBlockObjectWriter(
    */
   private var numRecordsCommitted = 0L
 
+  // For testing only.
+  private[storage] def getSerializerWrappedStream: OutputStream = bs
+
+  // For testing only.
+  private[storage] def getSerializationStream: SerializationStream = objOut
+
   /**
    * Set the checksum that the checksumOutputStream should use
    */
@@ -174,19 +180,36 @@ private[spark] class DiskBlockObjectWriter(
    * Should call after committing or reverting partial writes.
    */
   private def closeResources(): Unit = {
-    if (initialized) {
-      Utils.tryWithSafeFinally {
-        mcs.manualClose()
-      } {
-        channel = null
-        mcs = null
-        bs = null
-        fos = null
-        ts = null
-        objOut = null
-        initialized = false
-        streamOpen = false
-        hasBeenClosed = true
+    try {
+      if (streamOpen) {
+        Utils.tryWithSafeFinally {
+          if (null != objOut) objOut.close()
+          bs = null
+        } {
+          objOut = null
+          if (null != bs) bs.close()
+          bs = null
+        }
+      }
+    } catch {
+      case e: IOException =>
+        logInfo(log"Exception occurred while closing the output stream" +
+          log"${MDC(ERROR, e.getMessage)}")
+    } finally {
+      if (initialized) {
+        Utils.tryWithSafeFinally {
+          mcs.manualClose()
+        } {
+          channel = null
+          mcs = null
+          bs = null
+          fos = null
+          ts = null
+          objOut = null
+          initialized = false
+          streamOpen = false
+          hasBeenClosed = true
+        }
       }
     }
   }
```
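To make the control flow above easier to follow outside the diff, here is a minimal standalone sketch of the same close-then-cleanup pattern, using plain `try`/`finally` and hypothetical stand-in fields rather than Spark's `Utils.tryWithSafeFinally` and the real `objOut`/`bs`/`mcs` members:

```scala
import java.io.{Closeable, IOException}

// Hypothetical stand-ins for DiskBlockObjectWriter's objOut, bs and mcs fields.
final class WriterResources(var objOut: Closeable, var bs: Closeable, mcs: Closeable) {

  def closeResources(): Unit = {
    try {
      try {
        if (objOut != null) objOut.close() // closing objOut normally also closes the wrapped bs
        bs = null                          // mark bs handled so the fallback below skips it
      } finally {
        objOut = null
        if (bs != null) bs.close()         // only still non-null here if objOut.close() threw
        bs = null
      }
    } catch {
      case e: IOException =>
        // Best effort: record the failure and keep going so the cleanup below still runs.
        println(s"Exception occurred while closing the output stream: ${e.getMessage}")
    } finally {
      mcs.close() // the real code calls mcs.manualClose() and resets the remaining writer state
    }
  }
}
```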

core/src/test/scala/org/apache/spark/storage/DiskBlockObjectWriterSuite.scala

Lines changed: 78 additions & 4 deletions
```diff
@@ -16,11 +16,14 @@
  */
 package org.apache.spark.storage
 
-import java.io.File
+import java.io.{File, InputStream, OutputStream}
+import java.nio.ByteBuffer
+
+import scala.reflect.ClassTag
 
 import org.apache.spark.{SparkConf, SparkException, SparkFunSuite}
 import org.apache.spark.executor.ShuffleWriteMetrics
-import org.apache.spark.serializer.{JavaSerializer, SerializerManager}
+import org.apache.spark.serializer.{DeserializationStream, JavaSerializer, SerializationStream, Serializer, SerializerInstance, SerializerManager}
 import org.apache.spark.util.Utils
 
 class DiskBlockObjectWriterSuite extends SparkFunSuite {
@@ -43,10 +46,14 @@ class DiskBlockObjectWriterSuite extends SparkFunSuite {
   private def createWriter(): (DiskBlockObjectWriter, File, ShuffleWriteMetrics) = {
     val file = new File(tempDir, "somefile")
     val conf = new SparkConf()
-    val serializerManager = new SerializerManager(new JavaSerializer(conf), conf)
+    val serializerManager = new CustomSerializerManager(new JavaSerializer(conf), conf, None)
     val writeMetrics = new ShuffleWriteMetrics()
     val writer = new DiskBlockObjectWriter(
-      file, serializerManager, new JavaSerializer(new SparkConf()).newInstance(), 1024, true,
+      file,
+      serializerManager,
+      new CustomJavaSerializer(new SparkConf()).newInstance(),
+      1024,
+      true,
       writeMetrics)
     (writer, file, writeMetrics)
   }
@@ -196,9 +203,76 @@ class DiskBlockObjectWriterSuite extends SparkFunSuite {
     for (i <- 1 to 500) {
       writer.write(i, i)
     }
+
+    val bs = writer.getSerializerWrappedStream.asInstanceOf[OutputStreamWithCloseDetecting]
+    val objOut = writer.getSerializationStream.asInstanceOf[SerializationStreamWithCloseDetecting]
+
     writer.closeAndDelete()
     assert(!file.exists())
     assert(writeMetrics.bytesWritten == 0)
     assert(writeMetrics.recordsWritten == 0)
+    assert(bs.isClosed)
+    assert(objOut.isClosed)
+  }
+}
+
+trait CloseDetecting {
+  var isClosed = false
+}
+
+class OutputStreamWithCloseDetecting(outputStream: OutputStream)
+  extends OutputStream
+  with CloseDetecting {
+  override def write(b: Int): Unit = outputStream.write(b)
+
+  override def close(): Unit = {
+    isClosed = true
+    outputStream.close()
+  }
+}
+
+class CustomSerializerManager(
+    defaultSerializer: Serializer,
+    conf: SparkConf,
+    encryptionKey: Option[Array[Byte]])
+  extends SerializerManager(defaultSerializer, conf, encryptionKey) {
+  override def wrapStream(blockId: BlockId, s: OutputStream): OutputStream = {
+    new OutputStreamWithCloseDetecting(wrapForCompression(blockId, wrapForEncryption(s)))
+  }
+}
+
+class CustomJavaSerializer(conf: SparkConf) extends JavaSerializer(conf) {
+
+  override def newInstance(): SerializerInstance = {
+    new CustomJavaSerializerInstance(super.newInstance())
   }
 }
+
+class SerializationStreamWithCloseDetecting(serializationStream: SerializationStream)
+  extends SerializationStream with CloseDetecting {
+
+  override def close(): Unit = {
+    isClosed = true
+    serializationStream.close()
+  }
+
+  override def writeObject[T: ClassTag](t: T): SerializationStream =
+    serializationStream.writeObject(t)
+
+  override def flush(): Unit = serializationStream.flush()
+}
+
+class CustomJavaSerializerInstance(instance: SerializerInstance) extends SerializerInstance {
+  override def serializeStream(s: OutputStream): SerializationStream =
+    new SerializationStreamWithCloseDetecting(instance.serializeStream(s))
+
+  override def serialize[T: ClassTag](t: T): ByteBuffer = instance.serialize(t)
+
+  override def deserialize[T: ClassTag](bytes: ByteBuffer): T = instance.deserialize(bytes)
+
+  override def deserialize[T: ClassTag](bytes: ByteBuffer, loader: ClassLoader): T =
+    instance.deserialize(bytes, loader)
+
+  override def deserializeStream(s: InputStream): DeserializationStream =
+    instance.deserializeStream(s)
+}
```
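To exercise the new close-detection assertions locally, the suite can be run on its own; the command below follows the usual sbt-based Spark developer workflow and is given as an assumed example, to be adjusted to your build setup.

```
$> build/sbt "core/testOnly *DiskBlockObjectWriterSuite"
```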
