Skip to content

changed Limiter.requests to AtomicLong #214

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -32,27 +32,50 @@ internal suspend inline fun Flow<Payload>.collectLimiting(limiter: Limiter, cros
}
}

//TODO revisit 2 atomics and sync object
/**
* Maintains the amount of requests which the client is ready to consume and
* prevents sending further updates by suspending the sending coroutine
* if this amount reaches 0.
*
* ### Operation
*
* Each [useRequest] call decrements the maintained requests amount.
* Calling coroutine is suspended when this amount reaches 0.
* The coroutine is resumed when [updateRequests] is called.
*
*/
internal class Limiter(initial: Int) : SynchronizedObject() {
private val requests = atomic(initial)
private val awaiter = atomic<CancellableContinuation<Unit>?>(null)
private val requests: AtomicLong = atomic(initial.toLong())
private var awaiter: CancellableContinuation<Unit>? = null

fun updateRequests(n: Int) {
if (n <= 0) return
synchronized(this) {
requests += n
awaiter.getAndSet(null)?.takeIf(CancellableContinuation<Unit>::isActive)?.resume(Unit)
val updatedRequests = requests.value + n.toLong()
if (updatedRequests < 0) {
requests.value = Long.MAX_VALUE
} else {
requests.value = updatedRequests
}

if (awaiter?.isActive == true) {
awaiter?.resume(Unit)
awaiter = null
}
}
}

suspend fun useRequest() {
if (requests.getAndDecrement() > 0) {
if (requests.decrementAndGet() >= 0) {
currentCoroutineContext().ensureActive()
} else {
suspendCancellableCoroutine<Unit> {
suspendCancellableCoroutine<Unit> { continuation ->
synchronized(this) {
awaiter.value = it
if (requests.value >= 0 && it.isActive) it.resume(Unit)
if (requests.value >= 0 && continuation.isActive) {
continuation.resume(Unit)
} else {
this.awaiter = continuation
}
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ import io.rsocket.kotlin.keepalive.*
import io.rsocket.kotlin.payload.*
import io.rsocket.kotlin.test.*
import io.rsocket.kotlin.transport.local.*
import kotlinx.atomicfu.atomic
import kotlinx.coroutines.*
import kotlinx.coroutines.channels.*
import kotlinx.coroutines.flow.*
Expand Down Expand Up @@ -192,6 +193,40 @@ class RSocketTest : SuspendTest, TestWithLeakCheck {
assertTrue(channel.receiveCatching().isClosed)
}

@Test
fun testStreamInitialMaxValue() = test {
val requester = start(RSocketRequestHandler {
requestStream {
(0..9).asFlow().map {
payload(it.toString())
}
}
})
requester.requestStream(payload("HELLO"))
.flowOn(PrefetchStrategy(Int.MAX_VALUE, 0))
.test {
repeat(10) {
awaitItem().close()
}
awaitComplete()
}
}

@Test
fun testStreamRequestN() = test {
start(RSocketRequestHandler {
requestStream {
(0..9).asFlow().map { payload(it.toString()) }
}
})
.requestStream(payload("HELLO"))
.flowOn(PrefetchStrategy(5, 3))
.test {
repeat(10) { awaitItem().close() }
awaitComplete()
}
}

@Test
fun testChannel() = test {
val awaiter = Job()
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,249 @@
/*
* Copyright 2015-2020 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package io.rsocket.kotlin.internal

import app.cash.turbine.FlowTurbine
import io.rsocket.kotlin.*
import io.rsocket.kotlin.frame.*
import io.rsocket.kotlin.frame.io.Version
import io.rsocket.kotlin.keepalive.DefaultKeepAlive
import io.rsocket.kotlin.payload.DefaultPayloadMimeType
import io.rsocket.kotlin.payload.buildPayload
import io.rsocket.kotlin.payload.data
import io.rsocket.kotlin.test.TestExceptionHandler
import io.rsocket.kotlin.test.TestServer
import io.rsocket.kotlin.test.TestWithLeakCheck
import io.rsocket.kotlin.test.payload
import io.rsocket.kotlin.transport.ServerTransport
import kotlinx.coroutines.*
import kotlinx.coroutines.channels.Channel
import kotlinx.coroutines.flow.asFlow
import kotlinx.coroutines.flow.map
import kotlinx.coroutines.flow.onEach
import kotlin.test.Test
import kotlin.test.assertEquals
import kotlin.test.assertTrue
import kotlin.time.Duration.Companion.seconds

class RSocketResponderRequestNTest : TestWithLeakCheck, TestWithConnection() {
private val testJob: Job = Job()

private suspend fun start(handler: RSocket) {
val serverTransport = ServerTransport { accept ->
GlobalScope.async { accept(connection) }
}

val scope = CoroutineScope(Dispatchers.Unconfined + testJob + TestExceptionHandler)
@Suppress("DeferredResultUnused")
TestServer().bindIn(scope, serverTransport) {
config.setupPayload.close()
handler
}
}

override suspend fun after() {
super.after()
testJob.cancelAndJoin()
}

private val setupFrame
get() = SetupFrame(
version = Version.Current,
honorLease = false,
keepAlive = DefaultKeepAlive,
resumeToken = null,
payloadMimeType = DefaultPayloadMimeType,
payload = payload("setup"),
)

@Test
fun testStreamInitialEnoughToConsume() = test {
start(
RSocketRequestHandler {
requestStream { payload ->
payload.close()
(0..9).asFlow().map { buildPayload { data("$it") } }
}
}
)

connection.test {
connection.sendToReceiver(setupFrame)

connection.sendToReceiver(RequestStreamFrame(initialRequestN = 16, streamId = 1, payload = payload("request")))

awaitAndReleasePayloadFrames(amount = 10)
awaitCompleteFrame()
expectNoEventsIn(200)
}
}

@Test
fun testStreamSuspendWhenNoRequestsLeft() = test {
var lastSent = -1
start(
RSocketRequestHandler {
requestStream { payload ->
payload.close()
(0..9).asFlow()
.onEach { lastSent = it }
.map { buildPayload { data("$it") } }
}
}
)

connection.test {
connection.sendToReceiver(setupFrame)

connection.sendToReceiver(RequestStreamFrame(initialRequestN = 3, streamId = 1, payload = payload("request")))

awaitAndReleasePayloadFrames(amount = 3)
expectNoEventsIn(200)
assertEquals(3, lastSent)
}
}

@Test
fun testStreamRequestNFrameResumesOperation() = test {
start(
RSocketRequestHandler {
requestStream { payload ->
payload.close()
(0..15).asFlow().map { buildPayload { data("$it") } }
}
}
)
connection.test {
connection.sendToReceiver(setupFrame)

connection.sendToReceiver(RequestStreamFrame(initialRequestN = 3, streamId = 1, payload = payload("request")))
awaitAndReleasePayloadFrames(amount = 3)
expectNoEventsIn(200)

connection.sendToReceiver(RequestNFrame(streamId = 1, requestN = 5))
awaitAndReleasePayloadFrames(amount = 5)
expectNoEventsIn(200)

connection.sendToReceiver(RequestNFrame(streamId = 1, requestN = 5))
awaitAndReleasePayloadFrames(amount = 5)
expectNoEventsIn(200)
}
}

@Test
fun testStreamRequestNEnoughToComplete() = test {
val total = 20
start(
RSocketRequestHandler {
requestStream { payload ->
payload.close()
(0 until total).asFlow().map { buildPayload { data("$it") } }
}
}
)
connection.test {
connection.sendToReceiver(setupFrame)

val firstRequest = 3
connection.sendToReceiver(RequestStreamFrame(initialRequestN = firstRequest, streamId = 1, payload = payload("request")))
awaitAndReleasePayloadFrames(amount = firstRequest)
expectNoEventsIn(200)

connection.sendToReceiver(RequestNFrame(streamId = 1, requestN = Int.MAX_VALUE))
awaitAndReleasePayloadFrames(amount = total - firstRequest)
awaitCompleteFrame()
expectNoEventsIn(200)
}
}

@Test
fun testStreamRequestNAttemptedIntOverflow() = test {
val latch = Channel<Unit>(1)
start(
RSocketRequestHandler {
requestStream { payload ->
payload.close()
latch.receive()
// make sure limiter has got the RequestNFrame before emitting the values
delay(200)
(0..19).asFlow().map { buildPayload { data("$it") } }
}
}
)
connection.test {
connection.sendToReceiver(setupFrame)

connection.sendToReceiver(RequestStreamFrame(initialRequestN = Int.MAX_VALUE, streamId = 1, payload = payload("request")))
connection.sendToReceiver(RequestNFrame(streamId = 1, requestN = Int.MAX_VALUE))
latch.send(Unit)

awaitAndReleasePayloadFrames(amount = 20)
awaitCompleteFrame()
expectNoEventsIn(200)
}
}


@Test
fun testStreamRequestNSummingUpToOverflow() = test {
val latch = Channel<Unit>(1)
start(
RSocketRequestHandler {
requestStream { payload ->
payload.close()
latch.receive()
// make sure limiter has got the RequestNFrame before emitting the values
delay(200)
(0..19).asFlow().map { buildPayload { data("$it") } }
}
}
)

connection.test {
connection.sendToReceiver(setupFrame)

connection.sendToReceiver(RequestStreamFrame(initialRequestN = 5, streamId = 1, payload = payload("request")))
connection.sendToReceiver(RequestNFrame(streamId = 1, requestN = Int.MAX_VALUE / 3))
connection.sendToReceiver(RequestNFrame(streamId = 1, requestN = Int.MAX_VALUE / 3))
connection.sendToReceiver(RequestNFrame(streamId = 1, requestN = Int.MAX_VALUE / 3))
connection.sendToReceiver(RequestNFrame(streamId = 1, requestN = Int.MAX_VALUE / 3))
latch.send(Unit)

awaitAndReleasePayloadFrames(amount = 20)
awaitCompleteFrame()
expectNoEventsIn(200)
}
}

private suspend fun FlowTurbine<Frame>.awaitAndReleasePayloadFrames(amount: Int) {
repeat(amount) {
awaitFrame { frame ->
assertTrue(frame is RequestFrame)
assertEquals(FrameType.Payload, frame.type)
frame.payload.close()
}
}
}

private suspend fun FlowTurbine<Frame>.awaitCompleteFrame() {
awaitFrame { frame ->
assertTrue(frame is RequestFrame)
assertEquals(FrameType.Payload, frame.type)
assertTrue(frame.complete, "Frame should be complete")
}
}
}