Skip to content

Commit 336ca8c

Browse files
cloud-fanNgone51
authored andcommitted
[SPARK-49386][CORE][SQL][FOLLOWUP] More accurate memory tracking for memory based spill threshold
### What changes were proposed in this pull request? This is a followup of #47856 . It makes the memory tracking more accurate in several places: 1. In `ShuffleExternalSorter`/`UnsafeExternalSorter`, the memory is used by both the sorter itself, and its underlying in-memort sorter (for sorting shuffle partition ids). We need to add them up to calcuate the current memory usage. 2. In `ExternalAppendOnlyUnsafeRowArray`, the records are inserted to an in-memory buffer first. If the buffer gets too large (currently based on num records), we switch to `UnsafeExternalSorter`. The in-memory buffer also needs a memory based threshold ### Why are the changes needed? More accurate memory tracking results to better spill decisions ### Does this PR introduce _any_ user-facing change? No, the feature is not released yet. ### How was this patch tested? existing tests ### Was this patch authored or co-authored using generative AI tooling? no Closes #52190 from cloud-fan/spill. Lead-authored-by: Wenchen Fan <cloud0fan@gmail.com> Co-authored-by: Wenchen Fan <wenchen@databricks.com> Signed-off-by: Yi Wu <yi.wu@databricks.com>
1 parent 65c5775 commit 336ca8c

File tree

14 files changed

+142
-89
lines changed

14 files changed

+142
-89
lines changed

core/src/main/java/org/apache/spark/shuffle/sort/ShuffleExternalSorter.java

Lines changed: 18 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -90,9 +90,9 @@ final class ShuffleExternalSorter extends MemoryConsumer implements ShuffleCheck
9090
private final int numElementsForSpillThreshold;
9191

9292
/**
93-
* Force this sorter to spill when the size in memory is beyond this threshold.
93+
* Force this sorter to spill when the in memory size in bytes is beyond this threshold.
9494
*/
95-
private final long recordsSizeForSpillThreshold;
95+
private final long sizeInBytesForSpillThreshold;
9696

9797
/** The buffer size to use when writing spills using DiskBlockObjectWriter */
9898
private final int fileBufferSizeBytes;
@@ -117,7 +117,7 @@ final class ShuffleExternalSorter extends MemoryConsumer implements ShuffleCheck
117117
@Nullable private ShuffleInMemorySorter inMemSorter;
118118
@Nullable private MemoryBlock currentPage = null;
119119
private long pageCursor = -1;
120-
private long inMemRecordsSize = 0;
120+
private long totalPageMemoryUsageBytes = 0;
121121

122122
// Checksum calculator for each partition. Empty when shuffle checksum disabled.
123123
private final Checksum[] partitionChecksums;
@@ -142,7 +142,7 @@ final class ShuffleExternalSorter extends MemoryConsumer implements ShuffleCheck
142142
(int) (long) conf.get(package$.MODULE$.SHUFFLE_FILE_BUFFER_SIZE()) * 1024;
143143
this.numElementsForSpillThreshold =
144144
(int) conf.get(package$.MODULE$.SHUFFLE_SPILL_NUM_ELEMENTS_FORCE_SPILL_THRESHOLD());
145-
this.recordsSizeForSpillThreshold =
145+
this.sizeInBytesForSpillThreshold =
146146
(long) conf.get(package$.MODULE$.SHUFFLE_SPILL_MAX_SIZE_FORCE_SPILL_THRESHOLD());
147147
this.writeMetrics = writeMetrics;
148148
this.inMemSorter = new ShuffleInMemorySorter(
@@ -314,11 +314,7 @@ public long spill(long size, MemoryConsumer trigger) throws IOException {
314314
}
315315

316316
private long getMemoryUsage() {
317-
long totalPageSize = 0;
318-
for (MemoryBlock page : allocatedPages) {
319-
totalPageSize += page.size();
320-
}
321-
return ((inMemSorter == null) ? 0 : inMemSorter.getMemoryUsage()) + totalPageSize;
317+
return ((inMemSorter == null) ? 0 : inMemSorter.getMemoryUsage()) + totalPageMemoryUsageBytes;
322318
}
323319

324320
private void updatePeakMemoryUsed() {
@@ -342,11 +338,11 @@ private long freeMemory() {
342338
for (MemoryBlock block : allocatedPages) {
343339
memoryFreed += block.size();
344340
freePage(block);
341+
totalPageMemoryUsageBytes -= block.size();
345342
}
346343
allocatedPages.clear();
347344
currentPage = null;
348345
pageCursor = 0;
349-
inMemRecordsSize = 0;
350346
return memoryFreed;
351347
}
352348

@@ -417,6 +413,7 @@ private void acquireNewPageIfNecessary(int required) {
417413
currentPage = allocatePage(required);
418414
pageCursor = currentPage.getBaseOffset();
419415
allocatedPages.add(currentPage);
416+
totalPageMemoryUsageBytes += currentPage.size();
420417
}
421418
}
422419

@@ -432,10 +429,17 @@ public void insertRecord(Object recordBase, long recordOffset, int length, int p
432429
MDC.of(LogKeys.NUM_ELEMENTS_SPILL_RECORDS, inMemSorter.numRecords()),
433430
MDC.of(LogKeys.NUM_ELEMENTS_SPILL_THRESHOLD, numElementsForSpillThreshold));
434431
spill();
435-
} else if (inMemRecordsSize >= recordsSizeForSpillThreshold) {
436-
logger.info("Spilling data because size of spilledRecords ({}) crossed the size threshold {}",
437-
MDC.of(LogKeys.SPILL_RECORDS_SIZE, inMemRecordsSize),
438-
MDC.of(LogKeys.SPILL_RECORDS_SIZE_THRESHOLD, recordsSizeForSpillThreshold));
432+
}
433+
434+
// TODO: Ideally we only need to check the spill threshold when new memory needs to be
435+
// allocated (both this sorter and the underlying ShuffleInMemorySorter may allocate
436+
// new memory), but it's simpler to check the total memory usage of these two sorters
437+
// before inserting each record.
438+
final long usedMemory = getMemoryUsage();
439+
if (usedMemory >= sizeInBytesForSpillThreshold) {
440+
logger.info("Spilling data because memory usage ({}) crossed the threshold {}",
441+
MDC.of(LogKeys.SPILL_RECORDS_SIZE, usedMemory),
442+
MDC.of(LogKeys.SPILL_RECORDS_SIZE_THRESHOLD, sizeInBytesForSpillThreshold));
439443
spill();
440444
}
441445

@@ -453,7 +457,6 @@ public void insertRecord(Object recordBase, long recordOffset, int length, int p
453457
Platform.copyMemory(recordBase, recordOffset, base, pageCursor, length);
454458
pageCursor += length;
455459
inMemSorter.insertRecord(recordAddress, partitionId);
456-
inMemRecordsSize += required;
457460
}
458461

459462
/**

core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java

Lines changed: 24 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -80,9 +80,9 @@ public final class UnsafeExternalSorter extends MemoryConsumer {
8080
private final int numElementsForSpillThreshold;
8181

8282
/**
83-
* Force this sorter to spill when the size in memory is beyond this threshold.
83+
* Force this sorter to spill when the in memory size in bytes is beyond this threshold.
8484
*/
85-
private final long recordsSizeForSpillThreshold;
85+
private final long sizeInBytesForSpillThreshold;
8686

8787
/**
8888
* Memory pages that hold the records being sorted. The pages in this list are freed when
@@ -96,7 +96,7 @@ public final class UnsafeExternalSorter extends MemoryConsumer {
9696

9797
// These variables are reset after spilling:
9898
@Nullable private volatile UnsafeInMemorySorter inMemSorter;
99-
private long inMemRecordsSize = 0;
99+
private long totalPageMemoryUsageBytes = 0;
100100

101101
private MemoryBlock currentPage = null;
102102
private long pageCursor = -1;
@@ -115,12 +115,12 @@ public static UnsafeExternalSorter createWithExistingInMemorySorter(
115115
int initialSize,
116116
long pageSizeBytes,
117117
int numElementsForSpillThreshold,
118-
long recordsSizeForSpillThreshold,
118+
long sizeInBytesForSpillThreshold,
119119
UnsafeInMemorySorter inMemorySorter,
120120
long existingMemoryConsumption) throws IOException {
121121
UnsafeExternalSorter sorter = new UnsafeExternalSorter(taskMemoryManager, blockManager,
122122
serializerManager, taskContext, recordComparatorSupplier, prefixComparator, initialSize,
123-
pageSizeBytes, numElementsForSpillThreshold, recordsSizeForSpillThreshold,
123+
pageSizeBytes, numElementsForSpillThreshold, sizeInBytesForSpillThreshold,
124124
inMemorySorter, false /* ignored */);
125125
sorter.spill(Long.MAX_VALUE, sorter);
126126
taskContext.taskMetrics().incMemoryBytesSpilled(existingMemoryConsumption);
@@ -140,11 +140,11 @@ public static UnsafeExternalSorter create(
140140
int initialSize,
141141
long pageSizeBytes,
142142
int numElementsForSpillThreshold,
143-
long recordsSizeForSpillThreshold,
143+
long sizeInBytesForSpillThreshold,
144144
boolean canUseRadixSort) {
145145
return new UnsafeExternalSorter(taskMemoryManager, blockManager, serializerManager,
146146
taskContext, recordComparatorSupplier, prefixComparator, initialSize, pageSizeBytes,
147-
numElementsForSpillThreshold, recordsSizeForSpillThreshold, null, canUseRadixSort);
147+
numElementsForSpillThreshold, sizeInBytesForSpillThreshold, null, canUseRadixSort);
148148
}
149149

150150
private UnsafeExternalSorter(
@@ -157,7 +157,7 @@ private UnsafeExternalSorter(
157157
int initialSize,
158158
long pageSizeBytes,
159159
int numElementsForSpillThreshold,
160-
long recordsSizeForSpillThreshold,
160+
long sizeInBytesForSpillThreshold,
161161
@Nullable UnsafeInMemorySorter existingInMemorySorter,
162162
boolean canUseRadixSort) {
163163
super(taskMemoryManager, pageSizeBytes, taskMemoryManager.getTungstenMemoryMode());
@@ -187,7 +187,7 @@ private UnsafeExternalSorter(
187187
this.inMemSorter = existingInMemorySorter;
188188
}
189189
this.peakMemoryUsedBytes = getMemoryUsage();
190-
this.recordsSizeForSpillThreshold = recordsSizeForSpillThreshold;
190+
this.sizeInBytesForSpillThreshold = sizeInBytesForSpillThreshold;
191191
this.numElementsForSpillThreshold = numElementsForSpillThreshold;
192192

193193
// Register a cleanup task with TaskContext to ensure that memory is guaranteed to be freed at
@@ -248,7 +248,6 @@ public long spill(long size, MemoryConsumer trigger) throws IOException {
248248
// pages will currently be counted as memory spilled even though that space isn't actually
249249
// written to disk. This also counts the space needed to store the sorter's pointer array.
250250
inMemSorter.freeMemory();
251-
inMemRecordsSize = 0;
252251
// Reset the in-memory sorter's pointer array only after freeing up the memory pages holding the
253252
// records. Otherwise, if the task is over allocated memory, then without freeing the memory
254253
// pages, we might not be able to get memory for the pointer array.
@@ -264,11 +263,7 @@ public long spill(long size, MemoryConsumer trigger) throws IOException {
264263
* array.
265264
*/
266265
private long getMemoryUsage() {
267-
long totalPageSize = 0;
268-
for (MemoryBlock page : allocatedPages) {
269-
totalPageSize += page.size();
270-
}
271-
return ((inMemSorter == null) ? 0 : inMemSorter.getMemoryUsage()) + totalPageSize;
266+
return ((inMemSorter == null) ? 0 : inMemSorter.getMemoryUsage()) + totalPageMemoryUsageBytes;
272267
}
273268

274269
private void updatePeakMemoryUsed() {
@@ -320,6 +315,7 @@ private long freeMemory() {
320315
for (MemoryBlock block : pagesToFree) {
321316
memoryFreed += block.size();
322317
freePage(block);
318+
totalPageMemoryUsageBytes -= block.size();
323319
}
324320
return memoryFreed;
325321
}
@@ -378,6 +374,7 @@ public void cleanupResources() {
378374
} finally {
379375
for (MemoryBlock pageToFree : pagesToFree) {
380376
freePage(pageToFree);
377+
totalPageMemoryUsageBytes -= pageToFree.size();
381378
}
382379
if (inMemSorterToFree != null) {
383380
inMemSorterToFree.freeMemory();
@@ -448,6 +445,7 @@ private void acquireNewPageIfNecessary(int required) {
448445
currentPage = allocatePage(required);
449446
pageCursor = currentPage.getBaseOffset();
450447
allocatedPages.add(currentPage);
448+
totalPageMemoryUsageBytes += currentPage.size();
451449
}
452450
}
453451

@@ -495,10 +493,17 @@ public void insertRecord(
495493
MDC.of(LogKeys.NUM_ELEMENTS_SPILL_RECORDS, inMemSorter.numRecords()),
496494
MDC.of(LogKeys.NUM_ELEMENTS_SPILL_THRESHOLD, numElementsForSpillThreshold));
497495
spill();
498-
} else if (inMemRecordsSize >= recordsSizeForSpillThreshold) {
499-
logger.info("Spilling data because size of spilledRecords ({}) crossed the size threshold {}",
500-
MDC.of(LogKeys.SPILL_RECORDS_SIZE, inMemRecordsSize),
501-
MDC.of(LogKeys.SPILL_RECORDS_SIZE_THRESHOLD, recordsSizeForSpillThreshold));
496+
}
497+
498+
// TODO: Ideally we only need to check the spill threshold when new memory needs to be
499+
// allocated (both this sorter and the underlying UnsafeInMemorySorter may allocate
500+
// new memory), but it's simpler to check the total memory usage of these two sorters
501+
// before inserting each record.
502+
final long usedMemory = getMemoryUsage();
503+
if (usedMemory >= sizeInBytesForSpillThreshold) {
504+
logger.info("Spilling data because memory usage ({}) crossed the threshold {}",
505+
MDC.of(LogKeys.SPILL_RECORDS_SIZE, usedMemory),
506+
MDC.of(LogKeys.SPILL_RECORDS_SIZE_THRESHOLD, sizeInBytesForSpillThreshold));
502507
spill();
503508
}
504509

@@ -514,7 +519,6 @@ public void insertRecord(
514519
Platform.copyMemory(recordBase, recordOffset, base, pageCursor, length);
515520
pageCursor += length;
516521
inMemSorter.insertRecord(recordAddress, prefix, prefixIsNull);
517-
inMemRecordsSize += required;
518522
}
519523

520524
/**

core/src/main/scala/org/apache/spark/internal/config/package.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1599,9 +1599,9 @@ package object config {
15991599
.createWithDefault(Integer.MAX_VALUE)
16001600

16011601
private[spark] val SHUFFLE_SPILL_MAX_SIZE_FORCE_SPILL_THRESHOLD =
1602-
ConfigBuilder("spark.shuffle.spill.maxRecordsSizeForSpillThreshold")
1602+
ConfigBuilder("spark.shuffle.spill.maxSizeInBytesForSpillThreshold")
16031603
.internal()
1604-
.doc("The maximum size in memory before forcing the shuffle sorter to spill. " +
1604+
.doc("The maximum in memory size in bytes before forcing the shuffle sorter to spill. " +
16051605
"By default it is Long.MAX_VALUE, which means we never force the sorter to spill, " +
16061606
"until we reach some limitations, like the max page size limitation for the pointer " +
16071607
"array in the sorter.")

sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeKVExternalSorter.java

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -61,9 +61,9 @@ public UnsafeKVExternalSorter(
6161
SerializerManager serializerManager,
6262
long pageSizeBytes,
6363
int numElementsForSpillThreshold,
64-
long maxRecordsSizeForSpillThreshold) throws IOException {
64+
long sizeInBytesForSpillThreshold) throws IOException {
6565
this(keySchema, valueSchema, blockManager, serializerManager, pageSizeBytes,
66-
numElementsForSpillThreshold, maxRecordsSizeForSpillThreshold, null);
66+
numElementsForSpillThreshold, sizeInBytesForSpillThreshold, null);
6767
}
6868

6969
public UnsafeKVExternalSorter(
@@ -73,7 +73,7 @@ public UnsafeKVExternalSorter(
7373
SerializerManager serializerManager,
7474
long pageSizeBytes,
7575
int numElementsForSpillThreshold,
76-
long maxRecordsSizeForSpillThreshold,
76+
long sizeInBytesForSpillThreshold,
7777
@Nullable BytesToBytesMap map) throws IOException {
7878
this.keySchema = keySchema;
7979
this.valueSchema = valueSchema;
@@ -100,7 +100,7 @@ public UnsafeKVExternalSorter(
100100
(int) (long) SparkEnv.get().conf().get(package$.MODULE$.SHUFFLE_SORT_INIT_BUFFER_SIZE()),
101101
pageSizeBytes,
102102
numElementsForSpillThreshold,
103-
maxRecordsSizeForSpillThreshold,
103+
sizeInBytesForSpillThreshold,
104104
canUseRadixSort);
105105
} else {
106106
// During spilling, the pointer array in `BytesToBytesMap` will not be used, so we can borrow
@@ -168,7 +168,7 @@ public UnsafeKVExternalSorter(
168168
(int) (long) SparkEnv.get().conf().get(package$.MODULE$.SHUFFLE_SORT_INIT_BUFFER_SIZE()),
169169
pageSizeBytes,
170170
numElementsForSpillThreshold,
171-
maxRecordsSizeForSpillThreshold,
171+
sizeInBytesForSpillThreshold,
172172
inMemSorter,
173173
map.getTotalMemoryConsumption());
174174

0 commit comments

Comments
 (0)