File tree Expand file tree Collapse file tree 2 files changed +19
-3
lines changed
main/scala/org/apache/spark/shuffle
test/scala/org/apache/spark/shuffle Expand file tree Collapse file tree 2 files changed +19
-3
lines changed Original file line number Diff line number Diff line change @@ -66,8 +66,9 @@ private[spark] class ShuffleMemoryManager(maxMemory: Long) extends Logging {
66
66
val curMem = threadMemory(threadId)
67
67
val freeMemory = maxMemory - threadMemory.values.sum
68
68
69
- // How much we can grant this thread; don't let it grow to more than 1 / numActiveThreads
70
- val maxToGrant = math.min(numBytes, (maxMemory / numActiveThreads) - curMem)
69
+ // How much we can grant this thread; don't let it grow to more than 1 / numActiveThreads;
70
+ // don't let it be negative
71
+ val maxToGrant = math.min(numBytes, math.max(0 , (maxMemory / numActiveThreads) - curMem))
71
72
72
73
if (curMem < maxMemory / (2 * numActiveThreads)) {
73
74
// We want to let each thread get at least 1 / (2 * numActiveThreads) before blocking;
Original file line number Diff line number Diff line change @@ -159,7 +159,7 @@ class ShuffleMemoryManagerSuite extends FunSuite with Timeouts {
159
159
160
160
test(" threads can block to get at least 1 / 2N memory" ) {
161
161
// t1 grabs 1000 bytes and then waits until t2 is ready to make a request. It sleeps
162
- // for a bit and releases 250 bytes, which should then be greanted to t2. Further requests
162
+ // for a bit and releases 250 bytes, which should then be granted to t2. Further requests
163
163
// by t2 will return false right away because it now has 1 / 2N of the memory.
164
164
165
165
val manager = new ShuffleMemoryManager (1000L )
@@ -291,4 +291,19 @@ class ShuffleMemoryManagerSuite extends FunSuite with Timeouts {
291
291
assert(state.t2WaitTime > 200 , s " t2 waited less than 200 ms ( ${state.t2WaitTime}) " )
292
292
}
293
293
}
294
+
295
+ test(" threads should not be granted a negative size" ) {
296
+ val manager = new ShuffleMemoryManager (1000L )
297
+ manager.tryToAcquire(700L )
298
+
299
+ val latch = new CountDownLatch (1 )
300
+ startThread(" t1" ) {
301
+ manager.tryToAcquire(300L )
302
+ latch.countDown()
303
+ }
304
+ latch.await() // Wait until `t1` calls `tryToAcquire`
305
+
306
+ val granted = manager.tryToAcquire(300L )
307
+ assert(0 === granted, " granted is negative" )
308
+ }
294
309
}
You can’t perform that action at this time.
0 commit comments