Skip to content

Commit b5b37c6

Browse files
committed
changed LimitingFlowCollector.requests to AtomicLong
This avoids Int overflow when client is misbehaving and is sending multiple RequestN frames with n=Int.MAX_VALUE This closes #213
1 parent 8efcd11 commit b5b37c6

File tree

3 files changed

+362
-10
lines changed

3 files changed

+362
-10
lines changed

rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/internal/Limiter.kt

Lines changed: 45 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -32,27 +32,62 @@ internal suspend inline fun Flow<Payload>.collectLimiting(limiter: Limiter, cros
3232
}
3333
}
3434

35-
//TODO revisit 2 atomics and sync object
35+
/**
36+
* Maintains the amount of requests which the client is ready to consume and
37+
* prevents sending further updates by suspending the sending coroutine
38+
* if this amount reaches 0.
39+
*
40+
* ### Operation
41+
*
42+
* Each [useRequest] call decrements the maintained requests amount.
43+
* Calling coroutine is suspended when this amount reaches 0.
44+
* The coroutine is resumed when [updateRequests] is called.
45+
*
46+
* ### Unbounded mode
47+
*
48+
* Limiter enters an unbounded mode when:
49+
* * [Limiter] is created passing `Int.MAX_VALUE` as `initial`
50+
* * client sends a `RequestN` frame with `Int.MAX_VALUE`
51+
* * Internal Long counter overflows
52+
*
53+
* In unbounded mode Limiter will assume that the client
54+
* is able to process requests without limitations, all further
55+
* [updateRequests] will be NOP and [useRequest] will never suspend.
56+
*/
3657
internal class Limiter(initial: Int) : SynchronizedObject() {
37-
private val requests = atomic(initial)
38-
private val awaiter = atomic<CancellableContinuation<Unit>?>(null)
58+
private val requests: AtomicLong = atomic(initial.toLong())
59+
private val unbounded: AtomicBoolean = atomic(initial == Int.MAX_VALUE)
60+
private var awaiter: CancellableContinuation<Unit>? = null
3961

4062
fun updateRequests(n: Int) {
41-
if (n <= 0) return
63+
if (n <= 0 || unbounded.value) return
4264
synchronized(this) {
43-
requests += n
44-
awaiter.getAndSet(null)?.takeIf(CancellableContinuation<Unit>::isActive)?.resume(Unit)
65+
val updatedRequests = requests.value + n.toLong()
66+
if (updatedRequests < 0) {
67+
unbounded.value = true
68+
requests.value = Long.MAX_VALUE
69+
} else {
70+
requests.value = updatedRequests
71+
}
72+
73+
if (awaiter?.isActive == true) {
74+
awaiter?.resume(Unit)
75+
awaiter = null
76+
}
4577
}
4678
}
4779

4880
suspend fun useRequest() {
49-
if (requests.getAndDecrement() > 0) {
81+
if (unbounded.value || requests.decrementAndGet() >= 0) {
5082
currentCoroutineContext().ensureActive()
5183
} else {
52-
suspendCancellableCoroutine<Unit> {
84+
suspendCancellableCoroutine<Unit> { continuation ->
5385
synchronized(this) {
54-
awaiter.value = it
55-
if (requests.value >= 0 && it.isActive) it.resume(Unit)
86+
if (requests.value >= 0 && continuation.isActive) {
87+
continuation.resume(Unit)
88+
} else {
89+
this.awaiter = continuation
90+
}
5691
}
5792
}
5893
}

rsocket-core/src/commonTest/kotlin/io/rsocket/kotlin/core/RSocketTest.kt

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ import io.rsocket.kotlin.keepalive.*
2323
import io.rsocket.kotlin.payload.*
2424
import io.rsocket.kotlin.test.*
2525
import io.rsocket.kotlin.transport.local.*
26+
import kotlinx.atomicfu.atomic
2627
import kotlinx.coroutines.*
2728
import kotlinx.coroutines.channels.*
2829
import kotlinx.coroutines.flow.*
@@ -192,6 +193,52 @@ class RSocketTest : SuspendTest, TestWithLeakCheck {
192193
assertTrue(channel.receiveCatching().isClosed)
193194
}
194195

196+
@Test
197+
fun testStreamInitialUnbounded() = test {
198+
val requester = start(RSocketRequestHandler {
199+
requestStream {
200+
(0..9).asFlow().map {
201+
payload(it.toString())
202+
}
203+
}
204+
})
205+
requester.requestStream(payload("HELLO"))
206+
.flowOn(PrefetchStrategy(Int.MAX_VALUE, 0))
207+
.test {
208+
repeat(10) {
209+
awaitItem().close()
210+
}
211+
awaitComplete()
212+
}
213+
}
214+
215+
@Test
216+
fun testStreamRequestNUnbounded() = test {
217+
class UnboundedAfterNStrategy(private val initial: Int) : RequestStrategy {
218+
override fun provide(): RequestStrategy.Element = Element()
219+
inner class Element : RequestStrategy.Element {
220+
private val requested = atomic(initial)
221+
override suspend fun firstRequest(): Int = initial
222+
override suspend fun nextRequest(): Int {
223+
val requestUnbounded = requested.getAndDecrement() == 0
224+
return if (requestUnbounded) Int.MAX_VALUE else 0
225+
}
226+
}
227+
}
228+
229+
start(RSocketRequestHandler {
230+
requestStream {
231+
(0..9).asFlow().map { payload(it.toString()) }
232+
}
233+
})
234+
.requestStream(payload("HELLO"))
235+
.flowOn(UnboundedAfterNStrategy(initial = 5))
236+
.test {
237+
repeat(10) { awaitItem().close() }
238+
awaitComplete()
239+
}
240+
}
241+
195242
@Test
196243
fun testChannel() = test {
197244
val awaiter = Job()

0 commit comments

Comments
 (0)