Skip to content

Commit 7b57a83

Browse files
andrewbailey authored and igordmn committed
Cancel BroadcastFrameClock awaiters without locks
In the referenced bug, there's a deadlock where a call to `withFrameNanos` is being cancelled on one thread while another thread is dispatching a frame. To avoid the deadlock, this commit updates the BroadcastFrameClock awaiter so that it's possible to cancel an awaiter without acquiring any locks. Fixes: b/407027032 Test: BroadcastFrameClockTest.locklessCancellation Relnote: "Fixed a deadlock that may affect Molecule users when a suspended call to `FrameClock.withFrameNanos` is cancelled while a frame is being dispatched." Change-Id: I89cab8e3eab14ed9a85b36e151f11b5f526a01fd
1 parent 7fcfc94 commit 7b57a83

File tree

2 files changed

+168
-36
lines changed

2 files changed

+168
-36
lines changed

compose/runtime/runtime/src/commonMain/kotlin/androidx/compose/runtime/BroadcastFrameClock.kt

Lines changed: 123 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -16,12 +16,16 @@
1616

1717
package androidx.compose.runtime
1818

19+
import androidx.collection.mutableObjectListOf
1920
import androidx.compose.runtime.internal.AtomicInt
2021
import androidx.compose.runtime.platform.makeSynchronizedObject
2122
import androidx.compose.runtime.platform.synchronized
22-
import androidx.compose.runtime.snapshots.fastForEach
23-
import kotlin.coroutines.Continuation
23+
import kotlin.contracts.ExperimentalContracts
24+
import kotlin.contracts.InvocationKind
25+
import kotlin.contracts.contract
2426
import kotlin.coroutines.resumeWithException
27+
import kotlin.jvm.JvmInline
28+
import kotlinx.coroutines.CancellableContinuation
2529
import kotlinx.coroutines.CancellationException
2630
import kotlinx.coroutines.suspendCancellableCoroutine
2731

@@ -38,24 +42,34 @@ import kotlinx.coroutines.suspendCancellableCoroutine
3842
*/
3943
class BroadcastFrameClock(private val onNewAwaiters: (() -> Unit)? = null) : MonotonicFrameClock {
4044

41-
private class FrameAwaiter<R>(val onFrame: (Long) -> R, val continuation: Continuation<R>) {
45+
private class FrameAwaiter<R>(onFrame: (Long) -> R, continuation: CancellableContinuation<R>) {
46+
private var onFrame: ((Long) -> R)? = onFrame
47+
private var continuation: (CancellableContinuation<R>)? = continuation
48+
49+
fun cancel() {
50+
onFrame = null
51+
continuation = null
52+
}
53+
4254
fun resume(timeNanos: Long) {
43-
continuation.resumeWith(runCatching { onFrame(timeNanos) })
55+
val onFrame = onFrame ?: return
56+
continuation?.resumeWith(runCatching { onFrame(timeNanos) })
57+
}
58+
59+
fun resumeWithException(exception: Throwable) {
60+
continuation?.resumeWithException(exception)
4461
}
4562
}
4663

4764
private val lock = makeSynchronizedObject()
4865
private var failureCause: Throwable? = null
49-
private var awaiters = mutableListOf<FrameAwaiter<*>>()
50-
private var spareList = mutableListOf<FrameAwaiter<*>>()
51-
52-
// Uses AtomicInt to avoid adding AtomicBoolean to the Expect/Actual requirements of the
53-
// runtime.
54-
private val hasAwaitersUnlocked = AtomicInt(0)
66+
private val pendingAwaitersCountUnlocked = AtomicAwaitersCount()
67+
private var awaiters = mutableObjectListOf<FrameAwaiter<*>>()
68+
private var spareList = mutableObjectListOf<FrameAwaiter<*>>()
5569

5670
/** `true` if there are any callers of [withFrameNanos] awaiting to run for a pending frame. */
5771
val hasAwaiters: Boolean
58-
get() = hasAwaitersUnlocked.get() != 0
72+
get() = pendingAwaitersCountUnlocked.hasAwaiters()
5973

6074
/**
6175
* Send a frame for time [timeNanos] to all current callers of [withFrameNanos]. The `onFrame`
@@ -69,7 +83,7 @@ class BroadcastFrameClock(private val onNewAwaiters: (() -> Unit)? = null) : Mon
6983
val toResume = awaiters
7084
awaiters = spareList
7185
spareList = toResume
72-
hasAwaitersUnlocked.set(0)
86+
pendingAwaitersCountUnlocked.incrementVersionAndResetCount()
7387

7488
for (i in 0 until toResume.size) {
7589
toResume[i].resume(timeNanos)
@@ -81,24 +95,24 @@ class BroadcastFrameClock(private val onNewAwaiters: (() -> Unit)? = null) : Mon
8195
override suspend fun <R> withFrameNanos(onFrame: (Long) -> R): R =
8296
suspendCancellableCoroutine { co ->
8397
val awaiter = FrameAwaiter(onFrame, co)
84-
val hasNewAwaiters =
85-
synchronized(lock) {
86-
val cause = failureCause
87-
if (cause != null) {
88-
co.resumeWithException(cause)
89-
return@suspendCancellableCoroutine
90-
}
91-
val hadAwaiters = awaiters.isNotEmpty()
92-
awaiters.add(awaiter)
93-
if (!hadAwaiters) hasAwaitersUnlocked.set(1)
94-
!hadAwaiters
98+
var hasNewAwaiters = false
99+
var awaitersVersion = -1
100+
synchronized(lock) {
101+
val cause = failureCause
102+
if (cause != null) {
103+
co.resumeWithException(cause)
104+
return@suspendCancellableCoroutine
95105
}
106+
awaitersVersion =
107+
pendingAwaitersCountUnlocked.incrementCountAndGetVersion(
108+
ifFirstAwaiter = { hasNewAwaiters = true }
109+
)
110+
awaiters.add(awaiter)
111+
}
96112

97113
co.invokeOnCancellation {
98-
synchronized(lock) {
99-
awaiters.remove(awaiter)
100-
if (awaiters.isEmpty()) hasAwaitersUnlocked.set(0)
101-
}
114+
awaiter.cancel()
115+
pendingAwaitersCountUnlocked.decrementCount(awaitersVersion)
102116
}
103117

104118
// Wake up anything that was waiting for someone to schedule a frame
@@ -118,9 +132,9 @@ class BroadcastFrameClock(private val onNewAwaiters: (() -> Unit)? = null) : Mon
118132
synchronized(lock) {
119133
if (failureCause != null) return
120134
failureCause = cause
121-
awaiters.fastForEach { awaiter -> awaiter.continuation.resumeWithException(cause) }
135+
awaiters.forEach { awaiter -> awaiter.resumeWithException(cause) }
122136
awaiters.clear()
123-
hasAwaitersUnlocked.set(0)
137+
pendingAwaitersCountUnlocked.incrementVersionAndResetCount()
124138
}
125139
}
126140

@@ -133,4 +147,84 @@ class BroadcastFrameClock(private val onNewAwaiters: (() -> Unit)? = null) : Mon
133147
) {
134148
fail(cancellationException)
135149
}
150+
151+
/**
 * [BroadcastFrameClock] tracks the number of pending [FrameAwaiter]s using this atomic type.
 * The tracked value is made up of two components packed into a single [AtomicInt]: the count
 * itself ([COUNT_BITS] bits) and a version ([VERSION_BITS] bits).
 *
 * The count is incremented when a new awaiter is added, and decremented when an awaiter is
 * cancelled. When the pending awaiters are processed, this count is reset to zero. To prevent a
 * race condition that can cause an inaccurate count when awaiters are removed, cancelled
 * awaiters only decrement the count when the version of the counter has not changed. The
 * version is incremented every time the count is reset to zero.
 *
 * The number of bits required to track the version is very small, and the version is allowed
 * and expected to roll over. With [VERSION_BITS] = 4 there are 16 distinct versions, so a
 * cancellation is attributed correctly as long as its callback completes before the version has
 * been incremented 16 times (the 16th increment brings the version back to its starting value).
 * Most cancelled awaiters will invoke their cancellation logic almost immediately, so even a
 * narrow version range can be highly effective.
 *
 * If a stale cancellation does observe a wrapped-around version, [decrementCount] never takes
 * the count below zero, so the packed value stays well-formed.
 */
@Suppress("NOTHING_TO_INLINE")
@JvmInline
private value class AtomicAwaitersCount private constructor(private val value: AtomicInt) {
    /** Creates a counter with version 0 and count 0. */
    constructor() : this(AtomicInt(0))

    /** Returns `true` while the count component is greater than zero. */
    inline fun hasAwaiters(): Boolean = value.get().count > 0

    /** Atomically advances the version by one (wrapping) and clears the count. */
    inline fun incrementVersionAndResetCount() {
        update { pack(version = it.version + 1, count = 0) }
    }

    /**
     * Atomically adds one to the count and returns the version the increment was recorded
     * under. [ifFirstAwaiter] is invoked (at most once) when this call moved the count to 1.
     */
    @OptIn(ExperimentalContracts::class)
    inline fun incrementCountAndGetVersion(ifFirstAwaiter: () -> Unit): Int {
        contract { callsInPlace(ifFirstAwaiter, InvocationKind.AT_MOST_ONCE) }
        val newValue = update { it + 1 }
        if (newValue.count == 1) ifFirstAwaiter()
        return newValue.version
    }

    /**
     * Atomically subtracts one from the count, but only if the counter is still at [version]
     * and the count is non-zero. Otherwise the value is left untouched.
     */
    inline fun decrementCount(version: Int) {
        update { current ->
            // The count lives in the low bits, so subtracting from a zero count would borrow
            // from the version bits and corrupt both fields. A zero count with a matching
            // version can only be seen by a stale cancellation after the version wrapped.
            if (current.version == version && current.count > 0) current - 1 else current
        }
    }

    /**
     * Compare-and-set retry loop: recomputes [calculation] from the latest value until the
     * swap succeeds, then returns the value that was stored.
     */
    private inline fun update(calculation: (Int) -> Int): Int {
        var oldValue: Int
        var newValue: Int
        do {
            oldValue = value.get()
            newValue = calculation(oldValue)
        } while (!value.compareAndSet(oldValue, newValue))
        return newValue
    }

    /**
     * Bitpacks [version] and [count] together. The topmost bit is always 0 to enforce this
     * value always being positive. [version] takes the next [VERSION_BITS] topmost bits, and
     * [count] takes the remaining [COUNT_BITS] bits. Both inputs are masked to their field
     * width, which is what makes the version wrap around.
     *
     * `| 0 | version | count |`
     */
    private fun pack(version: Int, count: Int): Int {
        val versionComponent = (version and (-1 shl VERSION_BITS).inv()) shl COUNT_BITS
        val countComponent = count and (-1 shl COUNT_BITS).inv()
        return versionComponent or countComponent
    }

    /** The version field of a packed value. */
    private inline val Int.version: Int
        get() = (this ushr COUNT_BITS) and (-1 shl VERSION_BITS).inv()

    /** The count field of a packed value. */
    private inline val Int.count: Int
        get() = this and (-1 shl COUNT_BITS).inv()

    override fun toString(): String {
        val current = value.get()
        return "AtomicAwaitersCount(version = ${current.version}, count = ${current.count})"
    }

    companion object {
        private const val VERSION_BITS = 4
        private const val COUNT_BITS = Int.SIZE_BITS - VERSION_BITS - 1
    }
}
136230
}

compose/runtime/runtime/src/nonEmulatorCommonTest/kotlin/androidx/compose/runtime/BroadcastFrameClockTest.kt

Lines changed: 45 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -14,24 +14,32 @@
1414
* limitations under the License.
1515
*/
1616

17-
package androidx.compose.runtime.dispatch
17+
package androidx.compose.runtime
1818

19+
import androidx.compose.runtime.internal.AtomicInt
1920
import kotlin.test.Test
2021
import kotlin.test.assertEquals
22+
import kotlin.test.assertFalse
2123
import kotlin.test.assertTrue
2224
import kotlinx.coroutines.CancellationException
25+
import kotlinx.coroutines.CoroutineStart.UNDISPATCHED
2326
import kotlinx.coroutines.Deferred
27+
import kotlinx.coroutines.Dispatchers
2428
import kotlinx.coroutines.ExperimentalCoroutinesApi
29+
import kotlinx.coroutines.InternalCoroutinesApi
2530
import kotlinx.coroutines.async
31+
import kotlinx.coroutines.cancelAndJoin
32+
import kotlinx.coroutines.launch
2633
import kotlinx.coroutines.test.UnconfinedTestDispatcher
2734
import kotlinx.coroutines.test.runTest
35+
import kotlinx.coroutines.yield
2836

2937
@ExperimentalCoroutinesApi
3038
class BroadcastFrameClockTest {
3139
@Test
3240
fun sendAndReceiveFrames() =
3341
runTest(UnconfinedTestDispatcher()) {
34-
val clock = androidx.compose.runtime.BroadcastFrameClock()
42+
val clock = BroadcastFrameClock()
3543

3644
val frameAwaiter = async { clock.withFrameNanos { it } }
3745

@@ -49,7 +57,7 @@ class BroadcastFrameClockTest {
4957
@Test
5058
fun cancelClock() =
5159
runTest(UnconfinedTestDispatcher()) {
52-
val clock = androidx.compose.runtime.BroadcastFrameClock()
60+
val clock = BroadcastFrameClock()
5361
val frameAwaiter = async { clock.withFrameNanos { it } }
5462

5563
clock.cancel()
@@ -66,15 +74,45 @@ class BroadcastFrameClockTest {
6674
@Test
6775
fun failClockWhenNewAwaitersNotified() =
6876
runTest(UnconfinedTestDispatcher()) {
69-
val clock =
70-
androidx.compose.runtime.BroadcastFrameClock {
71-
throw CancellationException("failed frame clock")
72-
}
77+
val clock = BroadcastFrameClock { throw CancellationException("failed frame clock") }
7378

7479
val failingAwaiter = async { clock.withFrameNanos { it } }
7580
assertAwaiterCancelled("failingAwaiter", failingAwaiter)
7681

7782
val lateAwaiter = async { clock.withFrameNanos { it } }
7883
assertAwaiterCancelled("lateAwaiter", lateAwaiter)
7984
}
85+
86+
/**
 * Regression test for cancelling a [BroadcastFrameClock] awaiter while a frame is being
 * dispatched on another thread. One awaiter's frame callback blocks inside `sendFrame` (running
 * on [Dispatchers.Default]) while the test thread cancels a second awaiter; the cancellation must
 * complete without waiting for the dispatch to finish.
 */
@OptIn(InternalCoroutinesApi::class)
@Test(timeout = 5_000)
fun locklessCancellation() = runTest {
    val clock = BroadcastFrameClock()
    val cancellationGate = AtomicInt(1)

    // Written by the test thread and polled by the thread running sendFrame, so it must be an
    // atomic rather than a plain captured `var`, whose update is not guaranteed to become
    // visible to the polling thread. 1 = keep blocking, 0 = release.
    val spin = AtomicInt(1)
    async(start = UNDISPATCHED) {
        clock.withFrameNanos {
            // Signal that the frame callback has started, then block the dispatching thread.
            cancellationGate.add(-1)
            @Suppress("BanThreadSleep") while (spin.get() != 0) Thread.sleep(100)
        }
    }

    val cancellingJob = async(start = UNDISPATCHED) { clock.withFrameNanos {} }

    launch(Dispatchers.Default) { clock.sendFrame(1) }

    // Wait for the spinlock to start
    while (cancellationGate.get() != 0) yield()

    // Assert that this line doesn't deadlock.
    cancellingJob.cancelAndJoin()

    // Make sure that we can queue up new jobs for subsequent frames
    spin.set(0)
    assertFalse(clock.hasAwaiters)
    async(start = UNDISPATCHED) { clock.withFrameNanos {} }
    assertTrue(clock.hasAwaiters)

    clock.cancel()
}
80118
}

0 commit comments

Comments
 (0)