Skip to content

Commit aa95bd3

Browse files
authored
changed Limiter.requests to AtomicLong (#214)
This avoids Int overflow when client is misbehaving and is sending multiple RequestN frames with n=Int.MAX_VALUE
1 parent 8efcd11 commit aa95bd3

File tree

3 files changed

+316
-9
lines changed

3 files changed

+316
-9
lines changed

rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/internal/Limiter.kt

Lines changed: 32 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -32,27 +32,50 @@ internal suspend inline fun Flow<Payload>.collectLimiting(limiter: Limiter, cros
3232
}
3333
}
3434

35-
//TODO revisit 2 atomics and sync object
35+
/**
36+
* Maintains the amount of requests which the client is ready to consume and
37+
* prevents sending further updates by suspending the sending coroutine
38+
* if this amount reaches 0.
39+
*
40+
* ### Operation
41+
*
42+
* Each [useRequest] call decrements the maintained requests amount.
43+
* Calling coroutine is suspended when this amount reaches 0.
44+
* The coroutine is resumed when [updateRequests] is called.
45+
*
46+
*/
3647
internal class Limiter(initial: Int) : SynchronizedObject() {
37-
private val requests = atomic(initial)
38-
private val awaiter = atomic<CancellableContinuation<Unit>?>(null)
48+
private val requests: AtomicLong = atomic(initial.toLong())
49+
private var awaiter: CancellableContinuation<Unit>? = null
3950

4051
fun updateRequests(n: Int) {
4152
if (n <= 0) return
4253
synchronized(this) {
43-
requests += n
44-
awaiter.getAndSet(null)?.takeIf(CancellableContinuation<Unit>::isActive)?.resume(Unit)
54+
val updatedRequests = requests.value + n.toLong()
55+
if (updatedRequests < 0) {
56+
requests.value = Long.MAX_VALUE
57+
} else {
58+
requests.value = updatedRequests
59+
}
60+
61+
if (awaiter?.isActive == true) {
62+
awaiter?.resume(Unit)
63+
awaiter = null
64+
}
4565
}
4666
}
4767

4868
suspend fun useRequest() {
49-
if (requests.getAndDecrement() > 0) {
69+
if (requests.decrementAndGet() >= 0) {
5070
currentCoroutineContext().ensureActive()
5171
} else {
52-
suspendCancellableCoroutine<Unit> {
72+
suspendCancellableCoroutine<Unit> { continuation ->
5373
synchronized(this) {
54-
awaiter.value = it
55-
if (requests.value >= 0 && it.isActive) it.resume(Unit)
74+
if (requests.value >= 0 && continuation.isActive) {
75+
continuation.resume(Unit)
76+
} else {
77+
this.awaiter = continuation
78+
}
5679
}
5780
}
5881
}

rsocket-core/src/commonTest/kotlin/io/rsocket/kotlin/core/RSocketTest.kt

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ import io.rsocket.kotlin.keepalive.*
2323
import io.rsocket.kotlin.payload.*
2424
import io.rsocket.kotlin.test.*
2525
import io.rsocket.kotlin.transport.local.*
26+
import kotlinx.atomicfu.atomic
2627
import kotlinx.coroutines.*
2728
import kotlinx.coroutines.channels.*
2829
import kotlinx.coroutines.flow.*
@@ -192,6 +193,40 @@ class RSocketTest : SuspendTest, TestWithLeakCheck {
192193
assertTrue(channel.receiveCatching().isClosed)
193194
}
194195

196+
@Test
197+
fun testStreamInitialMaxValue() = test {
198+
val requester = start(RSocketRequestHandler {
199+
requestStream {
200+
(0..9).asFlow().map {
201+
payload(it.toString())
202+
}
203+
}
204+
})
205+
requester.requestStream(payload("HELLO"))
206+
.flowOn(PrefetchStrategy(Int.MAX_VALUE, 0))
207+
.test {
208+
repeat(10) {
209+
awaitItem().close()
210+
}
211+
awaitComplete()
212+
}
213+
}
214+
215+
@Test
216+
fun testStreamRequestN() = test {
217+
start(RSocketRequestHandler {
218+
requestStream {
219+
(0..9).asFlow().map { payload(it.toString()) }
220+
}
221+
})
222+
.requestStream(payload("HELLO"))
223+
.flowOn(PrefetchStrategy(5, 3))
224+
.test {
225+
repeat(10) { awaitItem().close() }
226+
awaitComplete()
227+
}
228+
}
229+
195230
@Test
196231
fun testChannel() = test {
197232
val awaiter = Job()
Lines changed: 249 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,249 @@
1+
/*
2+
* Copyright 2015-2020 the original author or authors.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package io.rsocket.kotlin.internal
18+
19+
import app.cash.turbine.FlowTurbine
20+
import io.rsocket.kotlin.*
21+
import io.rsocket.kotlin.frame.*
22+
import io.rsocket.kotlin.frame.io.Version
23+
import io.rsocket.kotlin.keepalive.DefaultKeepAlive
24+
import io.rsocket.kotlin.payload.DefaultPayloadMimeType
25+
import io.rsocket.kotlin.payload.buildPayload
26+
import io.rsocket.kotlin.payload.data
27+
import io.rsocket.kotlin.test.TestExceptionHandler
28+
import io.rsocket.kotlin.test.TestServer
29+
import io.rsocket.kotlin.test.TestWithLeakCheck
30+
import io.rsocket.kotlin.test.payload
31+
import io.rsocket.kotlin.transport.ServerTransport
32+
import kotlinx.coroutines.*
33+
import kotlinx.coroutines.channels.Channel
34+
import kotlinx.coroutines.flow.asFlow
35+
import kotlinx.coroutines.flow.map
36+
import kotlinx.coroutines.flow.onEach
37+
import kotlin.test.Test
38+
import kotlin.test.assertEquals
39+
import kotlin.test.assertTrue
40+
import kotlin.time.Duration.Companion.seconds
41+
42+
class RSocketResponderRequestNTest : TestWithLeakCheck, TestWithConnection() {
43+
private val testJob: Job = Job()
44+
45+
private suspend fun start(handler: RSocket) {
46+
val serverTransport = ServerTransport { accept ->
47+
GlobalScope.async { accept(connection) }
48+
}
49+
50+
val scope = CoroutineScope(Dispatchers.Unconfined + testJob + TestExceptionHandler)
51+
@Suppress("DeferredResultUnused")
52+
TestServer().bindIn(scope, serverTransport) {
53+
config.setupPayload.close()
54+
handler
55+
}
56+
}
57+
58+
override suspend fun after() {
59+
super.after()
60+
testJob.cancelAndJoin()
61+
}
62+
63+
private val setupFrame
64+
get() = SetupFrame(
65+
version = Version.Current,
66+
honorLease = false,
67+
keepAlive = DefaultKeepAlive,
68+
resumeToken = null,
69+
payloadMimeType = DefaultPayloadMimeType,
70+
payload = payload("setup"),
71+
)
72+
73+
@Test
74+
fun testStreamInitialEnoughToConsume() = test {
75+
start(
76+
RSocketRequestHandler {
77+
requestStream { payload ->
78+
payload.close()
79+
(0..9).asFlow().map { buildPayload { data("$it") } }
80+
}
81+
}
82+
)
83+
84+
connection.test {
85+
connection.sendToReceiver(setupFrame)
86+
87+
connection.sendToReceiver(RequestStreamFrame(initialRequestN = 16, streamId = 1, payload = payload("request")))
88+
89+
awaitAndReleasePayloadFrames(amount = 10)
90+
awaitCompleteFrame()
91+
expectNoEventsIn(200)
92+
}
93+
}
94+
95+
@Test
96+
fun testStreamSuspendWhenNoRequestsLeft() = test {
97+
var lastSent = -1
98+
start(
99+
RSocketRequestHandler {
100+
requestStream { payload ->
101+
payload.close()
102+
(0..9).asFlow()
103+
.onEach { lastSent = it }
104+
.map { buildPayload { data("$it") } }
105+
}
106+
}
107+
)
108+
109+
connection.test {
110+
connection.sendToReceiver(setupFrame)
111+
112+
connection.sendToReceiver(RequestStreamFrame(initialRequestN = 3, streamId = 1, payload = payload("request")))
113+
114+
awaitAndReleasePayloadFrames(amount = 3)
115+
expectNoEventsIn(200)
116+
assertEquals(3, lastSent)
117+
}
118+
}
119+
120+
@Test
121+
fun testStreamRequestNFrameResumesOperation() = test {
122+
start(
123+
RSocketRequestHandler {
124+
requestStream { payload ->
125+
payload.close()
126+
(0..15).asFlow().map { buildPayload { data("$it") } }
127+
}
128+
}
129+
)
130+
connection.test {
131+
connection.sendToReceiver(setupFrame)
132+
133+
connection.sendToReceiver(RequestStreamFrame(initialRequestN = 3, streamId = 1, payload = payload("request")))
134+
awaitAndReleasePayloadFrames(amount = 3)
135+
expectNoEventsIn(200)
136+
137+
connection.sendToReceiver(RequestNFrame(streamId = 1, requestN = 5))
138+
awaitAndReleasePayloadFrames(amount = 5)
139+
expectNoEventsIn(200)
140+
141+
connection.sendToReceiver(RequestNFrame(streamId = 1, requestN = 5))
142+
awaitAndReleasePayloadFrames(amount = 5)
143+
expectNoEventsIn(200)
144+
}
145+
}
146+
147+
@Test
148+
fun testStreamRequestNEnoughToComplete() = test {
149+
val total = 20
150+
start(
151+
RSocketRequestHandler {
152+
requestStream { payload ->
153+
payload.close()
154+
(0 until total).asFlow().map { buildPayload { data("$it") } }
155+
}
156+
}
157+
)
158+
connection.test {
159+
connection.sendToReceiver(setupFrame)
160+
161+
val firstRequest = 3
162+
connection.sendToReceiver(RequestStreamFrame(initialRequestN = firstRequest, streamId = 1, payload = payload("request")))
163+
awaitAndReleasePayloadFrames(amount = firstRequest)
164+
expectNoEventsIn(200)
165+
166+
connection.sendToReceiver(RequestNFrame(streamId = 1, requestN = Int.MAX_VALUE))
167+
awaitAndReleasePayloadFrames(amount = total - firstRequest)
168+
awaitCompleteFrame()
169+
expectNoEventsIn(200)
170+
}
171+
}
172+
173+
@Test
174+
fun testStreamRequestNAttemptedIntOverflow() = test {
175+
val latch = Channel<Unit>(1)
176+
start(
177+
RSocketRequestHandler {
178+
requestStream { payload ->
179+
payload.close()
180+
latch.receive()
181+
// make sure limiter has got the RequestNFrame before emitting the values
182+
delay(200)
183+
(0..19).asFlow().map { buildPayload { data("$it") } }
184+
}
185+
}
186+
)
187+
connection.test {
188+
connection.sendToReceiver(setupFrame)
189+
190+
connection.sendToReceiver(RequestStreamFrame(initialRequestN = Int.MAX_VALUE, streamId = 1, payload = payload("request")))
191+
connection.sendToReceiver(RequestNFrame(streamId = 1, requestN = Int.MAX_VALUE))
192+
latch.send(Unit)
193+
194+
awaitAndReleasePayloadFrames(amount = 20)
195+
awaitCompleteFrame()
196+
expectNoEventsIn(200)
197+
}
198+
}
199+
200+
201+
@Test
202+
fun testStreamRequestNSummingUpToOverflow() = test {
203+
val latch = Channel<Unit>(1)
204+
start(
205+
RSocketRequestHandler {
206+
requestStream { payload ->
207+
payload.close()
208+
latch.receive()
209+
// make sure limiter has got the RequestNFrame before emitting the values
210+
delay(200)
211+
(0..19).asFlow().map { buildPayload { data("$it") } }
212+
}
213+
}
214+
)
215+
216+
connection.test {
217+
connection.sendToReceiver(setupFrame)
218+
219+
connection.sendToReceiver(RequestStreamFrame(initialRequestN = 5, streamId = 1, payload = payload("request")))
220+
connection.sendToReceiver(RequestNFrame(streamId = 1, requestN = Int.MAX_VALUE / 3))
221+
connection.sendToReceiver(RequestNFrame(streamId = 1, requestN = Int.MAX_VALUE / 3))
222+
connection.sendToReceiver(RequestNFrame(streamId = 1, requestN = Int.MAX_VALUE / 3))
223+
connection.sendToReceiver(RequestNFrame(streamId = 1, requestN = Int.MAX_VALUE / 3))
224+
latch.send(Unit)
225+
226+
awaitAndReleasePayloadFrames(amount = 20)
227+
awaitCompleteFrame()
228+
expectNoEventsIn(200)
229+
}
230+
}
231+
232+
private suspend fun FlowTurbine<Frame>.awaitAndReleasePayloadFrames(amount: Int) {
233+
repeat(amount) {
234+
awaitFrame { frame ->
235+
assertTrue(frame is RequestFrame)
236+
assertEquals(FrameType.Payload, frame.type)
237+
frame.payload.close()
238+
}
239+
}
240+
}
241+
242+
private suspend fun FlowTurbine<Frame>.awaitCompleteFrame() {
243+
awaitFrame { frame ->
244+
assertTrue(frame is RequestFrame)
245+
assertEquals(FrameType.Payload, frame.type)
246+
assertTrue(frame.complete, "Frame should be complete")
247+
}
248+
}
249+
}

0 commit comments

Comments
 (0)