Skip to content

Commit 124789b

Browse files
cloud-fanHyukjinKwon
authored andcommitted
[SPARK-25144][SQL][TEST][BRANCH-2.2] Free aggregate map when task ends
## What changes were proposed in this pull request? [SPARK-25144](https://issues.apache.org/jira/browse/SPARK-25144) reports memory leaks on Apache Spark 2.0.2 ~ 2.3.2-RC5. ```scala scala> case class Foo(bar: Option[String]) scala> val ds = List(Foo(Some("bar"))).toDS scala> val result = ds.flatMap(_.bar).distinct scala> result.rdd.isEmpty 18/08/19 23:01:54 WARN Executor: Managed memory leak detected; size = 8650752 bytes, TID = 125 res0: Boolean = false ``` This is a backport of cloud-fan 's #21738 which is a single commit among 3 commits of SPARK-21743. In addition, I added a test case to prevent regressions in branch-2.3 and branch-2.2. Although SPARK-21743 is reverted due to regression, this subpatch can go to branch-2.3 and branch-2.2. This will be merged as cloud-fan 's commit. ## How was this patch tested? Pass the jenkins with a newly added test case. Closes #22156 from dongjoon-hyun/SPARK-25144-2.2. Authored-by: Wenchen Fan <wenchen@databricks.com> Signed-off-by: hyukjinkwon <gurwls223@apache.org>
1 parent 6641aa6 commit 124789b

File tree

5 files changed

+35
-15
lines changed

5 files changed

+35
-15
lines changed

sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
import java.io.IOException;
2121

2222
import org.apache.spark.SparkEnv;
23-
import org.apache.spark.memory.TaskMemoryManager;
23+
import org.apache.spark.TaskContext;
2424
import org.apache.spark.sql.catalyst.InternalRow;
2525
import org.apache.spark.sql.catalyst.expressions.UnsafeProjection;
2626
import org.apache.spark.sql.catalyst.expressions.UnsafeRow;
@@ -84,7 +84,7 @@ public static boolean supportsAggregationBufferSchema(StructType schema) {
8484
* @param emptyAggregationBuffer the default value for new keys (a "zero" of the agg. function)
8585
* @param aggregationBufferSchema the schema of the aggregation buffer, used for row conversion.
8686
* @param groupingKeySchema the schema of the grouping key, used for row conversion.
87-
* @param taskMemoryManager the memory manager used to allocate our Unsafe memory structures.
87+
* @param taskContext the current task context.
8888
* @param initialCapacity the initial capacity of the map (a sizing hint to avoid re-hashing).
8989
* @param pageSizeBytes the data page size, in bytes; limits the maximum record size.
9090
* @param enablePerfMetrics if true, performance metrics will be recorded (has minor perf impact)
@@ -93,21 +93,28 @@ public UnsafeFixedWidthAggregationMap(
9393
InternalRow emptyAggregationBuffer,
9494
StructType aggregationBufferSchema,
9595
StructType groupingKeySchema,
96-
TaskMemoryManager taskMemoryManager,
96+
TaskContext taskContext,
9797
int initialCapacity,
9898
long pageSizeBytes,
9999
boolean enablePerfMetrics) {
100100
this.aggregationBufferSchema = aggregationBufferSchema;
101101
this.currentAggregationBuffer = new UnsafeRow(aggregationBufferSchema.length());
102102
this.groupingKeyProjection = UnsafeProjection.create(groupingKeySchema);
103103
this.groupingKeySchema = groupingKeySchema;
104-
this.map =
105-
new BytesToBytesMap(taskMemoryManager, initialCapacity, pageSizeBytes, enablePerfMetrics);
104+
this.map = new BytesToBytesMap(
105+
taskContext.taskMemoryManager(), initialCapacity, pageSizeBytes, enablePerfMetrics);
106106
this.enablePerfMetrics = enablePerfMetrics;
107107

108108
// Initialize the buffer for aggregation value
109109
final UnsafeProjection valueProjection = UnsafeProjection.create(aggregationBufferSchema);
110110
this.emptyAggregationBuffer = valueProjection.apply(emptyAggregationBuffer).getBytes();
111+
112+
// Register a cleanup task with TaskContext to ensure that memory is guaranteed to be freed at
113+
// the end of the task. This is necessary to avoid memory leaks in when the downstream operator
114+
// does not fully consume the aggregation map's output (e.g. aggregate followed by limit).
115+
taskContext.addTaskCompletionListener(context -> {
116+
free();
117+
});
111118
}
112119

113120
/**

sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -310,7 +310,7 @@ case class HashAggregateExec(
310310
initialBuffer,
311311
bufferSchema,
312312
groupingKeySchema,
313-
TaskContext.get().taskMemoryManager(),
313+
TaskContext.get(),
314314
1024 * 16, // initial capacity
315315
TaskContext.get().taskMemoryManager().pageSizeBytes,
316316
false // disable tracking of performance metrics

sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -160,7 +160,7 @@ class TungstenAggregationIterator(
160160
initialAggregationBuffer,
161161
StructType.fromAttributes(aggregateFunctions.flatMap(_.aggBufferAttributes)),
162162
StructType.fromAttributes(groupingExpressions.map(_.toAttribute)),
163-
TaskContext.get().taskMemoryManager(),
163+
TaskContext.get(),
164164
1024 * 16, // initial capacity
165165
TaskContext.get().taskMemoryManager().pageSizeBytes,
166166
false // disable tracking of performance metrics

sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2702,4 +2702,12 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext {
27022702
}
27032703
}
27042704
}
2705+
2706+
test("SPARK-25144 'distinct' causes memory leak") {
2707+
val ds = List(Foo(Some("bar"))).toDS
2708+
val result = ds.flatMap(_.bar).distinct
2709+
result.rdd.isEmpty
2710+
}
27052711
}
2712+
2713+
case class Foo(bar: Option[String])

sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala

Lines changed: 13 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ import scala.collection.mutable
2323
import scala.util.{Random, Try}
2424
import scala.util.control.NonFatal
2525

26+
import org.mockito.Mockito._
2627
import org.scalatest.Matchers
2728

2829
import org.apache.spark.{SparkConf, SparkFunSuite, TaskContext, TaskContextImpl}
@@ -53,6 +54,8 @@ class UnsafeFixedWidthAggregationMapSuite
5354
private var memoryManager: TestMemoryManager = null
5455
private var taskMemoryManager: TaskMemoryManager = null
5556

57+
private var taskContext: TaskContext = null
58+
5659
def testWithMemoryLeakDetection(name: String)(f: => Unit) {
5760
def cleanup(): Unit = {
5861
if (taskMemoryManager != null) {
@@ -66,6 +69,8 @@ class UnsafeFixedWidthAggregationMapSuite
6669
val conf = new SparkConf().set("spark.memory.offHeap.enabled", "false")
6770
memoryManager = new TestMemoryManager(conf)
6871
taskMemoryManager = new TaskMemoryManager(memoryManager, 0)
72+
taskContext = mock(classOf[TaskContext])
73+
when(taskContext.taskMemoryManager()).thenReturn(taskMemoryManager)
6974

7075
TaskContext.setTaskContext(new TaskContextImpl(
7176
stageId = 0,
@@ -110,7 +115,7 @@ class UnsafeFixedWidthAggregationMapSuite
110115
emptyAggregationBuffer,
111116
aggBufferSchema,
112117
groupKeySchema,
113-
taskMemoryManager,
118+
taskContext,
114119
1024, // initial capacity,
115120
PAGE_SIZE_BYTES,
116121
false // disable perf metrics
@@ -124,7 +129,7 @@ class UnsafeFixedWidthAggregationMapSuite
124129
emptyAggregationBuffer,
125130
aggBufferSchema,
126131
groupKeySchema,
127-
taskMemoryManager,
132+
taskContext,
128133
1024, // initial capacity
129134
PAGE_SIZE_BYTES,
130135
false // disable perf metrics
@@ -151,7 +156,7 @@ class UnsafeFixedWidthAggregationMapSuite
151156
emptyAggregationBuffer,
152157
aggBufferSchema,
153158
groupKeySchema,
154-
taskMemoryManager,
159+
taskContext,
155160
128, // initial capacity
156161
PAGE_SIZE_BYTES,
157162
false // disable perf metrics
@@ -177,7 +182,7 @@ class UnsafeFixedWidthAggregationMapSuite
177182
emptyAggregationBuffer,
178183
aggBufferSchema,
179184
groupKeySchema,
180-
taskMemoryManager,
185+
taskContext,
181186
128, // initial capacity
182187
PAGE_SIZE_BYTES,
183188
false // disable perf metrics
@@ -225,7 +230,7 @@ class UnsafeFixedWidthAggregationMapSuite
225230
emptyAggregationBuffer,
226231
aggBufferSchema,
227232
groupKeySchema,
228-
taskMemoryManager,
233+
taskContext,
229234
128, // initial capacity
230235
PAGE_SIZE_BYTES,
231236
false // disable perf metrics
@@ -266,7 +271,7 @@ class UnsafeFixedWidthAggregationMapSuite
266271
emptyAggregationBuffer,
267272
StructType(Nil),
268273
StructType(Nil),
269-
taskMemoryManager,
274+
taskContext,
270275
128, // initial capacity
271276
PAGE_SIZE_BYTES,
272277
false // disable perf metrics
@@ -311,7 +316,7 @@ class UnsafeFixedWidthAggregationMapSuite
311316
emptyAggregationBuffer,
312317
aggBufferSchema,
313318
groupKeySchema,
314-
taskMemoryManager,
319+
taskContext,
315320
128, // initial capacity
316321
pageSize,
317322
false // disable perf metrics
@@ -349,7 +354,7 @@ class UnsafeFixedWidthAggregationMapSuite
349354
emptyAggregationBuffer,
350355
aggBufferSchema,
351356
groupKeySchema,
352-
taskMemoryManager,
357+
taskContext,
353358
128, // initial capacity
354359
pageSize,
355360
false // disable perf metrics

0 commit comments

Comments
 (0)