16
16
17
17
package androidx.compose.runtime
18
18
19
+ import androidx.collection.mutableObjectListOf
19
20
import androidx.compose.runtime.internal.AtomicInt
20
21
import androidx.compose.runtime.platform.makeSynchronizedObject
21
22
import androidx.compose.runtime.platform.synchronized
22
- import androidx.compose.runtime.snapshots.fastForEach
23
- import kotlin.coroutines.Continuation
23
+ import kotlin.contracts.ExperimentalContracts
24
+ import kotlin.contracts.InvocationKind
25
+ import kotlin.contracts.contract
24
26
import kotlin.coroutines.resumeWithException
27
+ import kotlin.jvm.JvmInline
28
+ import kotlinx.coroutines.CancellableContinuation
25
29
import kotlinx.coroutines.CancellationException
26
30
import kotlinx.coroutines.suspendCancellableCoroutine
27
31
@@ -38,24 +42,34 @@ import kotlinx.coroutines.suspendCancellableCoroutine
38
42
*/
39
43
class BroadcastFrameClock (private val onNewAwaiters : (() -> Unit )? = null ) : MonotonicFrameClock {
40
44
41
- private class FrameAwaiter <R >(val onFrame : (Long ) -> R , val continuation : Continuation <R >) {
45
+ private class FrameAwaiter <R >(onFrame : (Long ) -> R , continuation : CancellableContinuation <R >) {
46
+ private var onFrame: ((Long ) -> R )? = onFrame
47
+ private var continuation: (CancellableContinuation <R >)? = continuation
48
+
49
+ fun cancel () {
50
+ onFrame = null
51
+ continuation = null
52
+ }
53
+
42
54
fun resume (timeNanos : Long ) {
43
- continuation.resumeWith(runCatching { onFrame(timeNanos) })
55
+ val onFrame = onFrame ? : return
56
+ continuation?.resumeWith(runCatching { onFrame(timeNanos) })
57
+ }
58
+
59
+ fun resumeWithException (exception : Throwable ) {
60
+ continuation?.resumeWithException(exception)
44
61
}
45
62
}
46
63
47
64
private val lock = makeSynchronizedObject()
48
65
private var failureCause: Throwable ? = null
49
- private var awaiters = mutableListOf<FrameAwaiter <* >>()
50
- private var spareList = mutableListOf<FrameAwaiter <* >>()
51
-
52
- // Uses AtomicInt to avoid adding AtomicBoolean to the Expect/Actual requirements of the
53
- // runtime.
54
- private val hasAwaitersUnlocked = AtomicInt (0 )
66
+ private val pendingAwaitersCountUnlocked = AtomicAwaitersCount ()
67
+ private var awaiters = mutableObjectListOf<FrameAwaiter <* >>()
68
+ private var spareList = mutableObjectListOf<FrameAwaiter <* >>()
55
69
56
70
/* * `true` if there are any callers of [withFrameNanos] awaiting to run for a pending frame. */
57
71
val hasAwaiters: Boolean
58
- get() = hasAwaitersUnlocked.get() != 0
72
+ get() = pendingAwaitersCountUnlocked.hasAwaiters()
59
73
60
74
/* *
61
75
* Send a frame for time [timeNanos] to all current callers of [withFrameNanos]. The `onFrame`
@@ -69,7 +83,7 @@ class BroadcastFrameClock(private val onNewAwaiters: (() -> Unit)? = null) : Mon
69
83
val toResume = awaiters
70
84
awaiters = spareList
71
85
spareList = toResume
72
- hasAwaitersUnlocked.set( 0 )
86
+ pendingAwaitersCountUnlocked.incrementVersionAndResetCount( )
73
87
74
88
for (i in 0 until toResume.size) {
75
89
toResume[i].resume(timeNanos)
@@ -81,24 +95,24 @@ class BroadcastFrameClock(private val onNewAwaiters: (() -> Unit)? = null) : Mon
81
95
override suspend fun <R > withFrameNanos (onFrame : (Long ) -> R ): R =
82
96
suspendCancellableCoroutine { co ->
83
97
val awaiter = FrameAwaiter (onFrame, co)
84
- val hasNewAwaiters =
85
- synchronized(lock) {
86
- val cause = failureCause
87
- if (cause != null ) {
88
- co.resumeWithException(cause)
89
- return @suspendCancellableCoroutine
90
- }
91
- val hadAwaiters = awaiters.isNotEmpty()
92
- awaiters.add(awaiter)
93
- if (! hadAwaiters) hasAwaitersUnlocked.set(1 )
94
- ! hadAwaiters
98
+ var hasNewAwaiters = false
99
+ var awaitersVersion = - 1
100
+ synchronized(lock) {
101
+ val cause = failureCause
102
+ if (cause != null ) {
103
+ co.resumeWithException(cause)
104
+ return @suspendCancellableCoroutine
95
105
}
106
+ awaitersVersion =
107
+ pendingAwaitersCountUnlocked.incrementCountAndGetVersion(
108
+ ifFirstAwaiter = { hasNewAwaiters = true }
109
+ )
110
+ awaiters.add(awaiter)
111
+ }
96
112
97
113
co.invokeOnCancellation {
98
- synchronized(lock) {
99
- awaiters.remove(awaiter)
100
- if (awaiters.isEmpty()) hasAwaitersUnlocked.set(0 )
101
- }
114
+ awaiter.cancel()
115
+ pendingAwaitersCountUnlocked.decrementCount(awaitersVersion)
102
116
}
103
117
104
118
// Wake up anything that was waiting for someone to schedule a frame
@@ -118,9 +132,9 @@ class BroadcastFrameClock(private val onNewAwaiters: (() -> Unit)? = null) : Mon
118
132
synchronized(lock) {
119
133
if (failureCause != null ) return
120
134
failureCause = cause
121
- awaiters.fastForEach { awaiter -> awaiter.continuation .resumeWithException(cause) }
135
+ awaiters.forEach { awaiter -> awaiter.resumeWithException(cause) }
122
136
awaiters.clear()
123
- hasAwaitersUnlocked.set( 0 )
137
+ pendingAwaitersCountUnlocked.incrementVersionAndResetCount( )
124
138
}
125
139
}
126
140
@@ -133,4 +147,84 @@ class BroadcastFrameClock(private val onNewAwaiters: (() -> Unit)? = null) : Mon
133
147
) {
134
148
fail(cancellationException)
135
149
}
150
+
151
+ /* *
152
+ * [BroadcastFrameClock] tracks the number of pending [FrameAwaiter]s using this atomic type.
153
+ * This count is made up of two components: The count itself ([COUNT_BITS] bits) and a version
154
+ * ([VERSION_BITS] bits).
155
+ *
156
+ * The count is incremented when a new awaiter is added, and decremented when an awaiter is
157
+ * cancelled. When the pending awaiters are processed, this count is reset to zero. To prevent a
158
+ * race condition that can cause an inaccurate count when awaiters are removed, cancelled
159
+ * awaiters only decrement their count when the version of the counter has not changed. The
160
+ * version is incremented every time the awaiters are dispatched and the count resets to zero.
161
+ *
162
+ * The number of bits required to track the version is very small, and the version is allowed
163
+ * and expected to roll over. By allocating 4 bits for the version, cancellation events can be
164
+ * correctly counted as long as the cancellation callback completes within 16 [sendFrame]
165
+ * invocations. Most cancelled awaiters will invoke their cancellation logic almost immediately,
166
+ * so even a narrow version range can be highly effective.
167
+ */
168
+ @Suppress(" NOTHING_TO_INLINE" )
169
+ @JvmInline
170
+ private value class AtomicAwaitersCount private constructor(private val value : AtomicInt ) {
171
+ constructor () : this (AtomicInt (0 ))
172
+
173
+ inline fun hasAwaiters (): Boolean = value.get().count > 0
174
+
175
+ inline fun incrementVersionAndResetCount () {
176
+ update { pack(version = it.version + 1 , count = 0 ) }
177
+ }
178
+
179
+ @OptIn(ExperimentalContracts ::class )
180
+ inline fun incrementCountAndGetVersion (ifFirstAwaiter : () -> Unit ): Int {
181
+ contract { callsInPlace(ifFirstAwaiter, InvocationKind .AT_MOST_ONCE ) }
182
+ val newValue = update { it + 1 }
183
+ if (newValue.count == 1 ) ifFirstAwaiter()
184
+ return newValue.version
185
+ }
186
+
187
+ inline fun decrementCount (version : Int ) {
188
+ update { value -> if (value.version == version) value - 1 else value }
189
+ }
190
+
191
+ private inline fun update (calculation : (Int ) -> Int ): Int {
192
+ var oldValue: Int
193
+ var newValue: Int
194
+ do {
195
+ oldValue = value.get()
196
+ newValue = calculation(oldValue)
197
+ } while (! value.compareAndSet(oldValue, newValue))
198
+ return newValue
199
+ }
200
+
201
+ /* *
202
+ * Bitpacks [version] and [count] together. The topmost bit is always 0 to enforce this
203
+ * value always being positive. [version] takes the next [VERSION_BITS] topmost bits, and
204
+ * [count] takes the remaining [COUNT_BITS] bits.
205
+ *
206
+ * `| 0 | version | count |`
207
+ */
208
+ private fun pack (version : Int , count : Int ): Int {
209
+ val versionComponent = (version and (- 1 shl VERSION_BITS ).inv ()) shl COUNT_BITS
210
+ val countComponent = count and (- 1 shl COUNT_BITS ).inv ()
211
+ return versionComponent or countComponent
212
+ }
213
+
214
+ private inline val Int .version: Int
215
+ get() = (this ushr COUNT_BITS ) and (- 1 shl VERSION_BITS ).inv ()
216
+
217
+ private inline val Int .count: Int
218
+ get() = this and (- 1 shl COUNT_BITS ).inv ()
219
+
220
+ override fun toString (): String {
221
+ val current = value.get()
222
+ return " AtomicAwaitersCount(version = ${current.version} , count = ${current.count} )"
223
+ }
224
+
225
+ companion object {
226
+ private const val VERSION_BITS = 4
227
+ private const val COUNT_BITS = Int .SIZE_BITS - VERSION_BITS - 1
228
+ }
229
+ }
136
230
}
0 commit comments