Skip to content

Rework request channel #125

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Dec 9, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -31,9 +31,9 @@ RSocket interface contains 5 methods:
* Request-Stream:

`fun requestStream(payload: Payload): Flow<Payload>`
* Request-Channel:
* Request-Channel:

`fun requestChannel(payloads: Flow<Payload>): Flow<Payload>`
`fun requestChannel(initPayload: Payload, payloads: Flow<Payload>): Flow<Payload>`
* Metadata-Push:

`suspend fun metadataPush(metadata: ByteReadPacket)`
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,10 @@ class RSocketKotlinBenchmark : RSocketBenchmark<Payload>() {
it.release()
payloadsFlow
}
requestChannel { it.flowOn(requestStrategy) }
requestChannel { init, payloads ->
init.release()
payloads.flowOn(requestStrategy)
}
}
}
client = runBlocking {
Expand Down Expand Up @@ -80,6 +83,6 @@ class RSocketKotlinBenchmark : RSocketBenchmark<Payload>() {

override suspend fun doRequestStream(): Flow<Payload> = client.requestStream(payloadCopy()).flowOn(requestStrategy)

override suspend fun doRequestChannel(): Flow<Payload> = client.requestChannel(payloadsFlow).flowOn(requestStrategy)
override suspend fun doRequestChannel(): Flow<Payload> = client.requestChannel(payloadCopy(), payloadsFlow).flowOn(requestStrategy)

}
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,8 @@ fun main(): Unit = runBlocking {
val server = LocalServer()
RSocketServer().bind(server) {
RSocketRequestHandler {
requestChannel { request ->
requestChannel { init, request ->
println("Init with: ${init.data.readText()}")
request.flowOn(PrefetchStrategy(3, 0)).take(3).flatMapConcat { payload ->
val data = payload.data.readText()
flow {
Expand All @@ -50,7 +51,7 @@ fun main(): Unit = runBlocking {
println("Client: No") //no print
}

val response = rSocket.requestChannel(request)
val response = rSocket.requestChannel(Payload("Init"), request)
response.collect {
val data = it.data.readText()
println("Client receives: $data")
Expand Down
2 changes: 1 addition & 1 deletion examples/multiplatform-chat/build.gradle.kts
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ val kotlinxSerializationVersion: String by rootProject
kotlin {
jvm("serverJvm")
jvm("clientJvm")
js("clientJs", LEGACY) {
js("clientJs", IR) {
browser {
binaries.executable()
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,8 @@ interface RSocket : Cancellable {
notImplemented("Request Stream")
}

fun requestChannel(payloads: Flow<Payload>): Flow<Payload> {
fun requestChannel(initPayload: Payload, payloads: Flow<Payload>): Flow<Payload> {
initPayload.release()
notImplemented("Request Channel")
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,12 +21,12 @@ import io.rsocket.kotlin.payload.*
import kotlinx.coroutines.*
import kotlinx.coroutines.flow.*

class RSocketRequestHandlerBuilder internal constructor() {
public class RSocketRequestHandlerBuilder internal constructor() {
private var metadataPush: (suspend RSocket.(metadata: ByteReadPacket) -> Unit)? = null
private var fireAndForget: (suspend RSocket.(payload: Payload) -> Unit)? = null
private var requestResponse: (suspend RSocket.(payload: Payload) -> Payload)? = null
private var requestStream: (RSocket.(payload: Payload) -> Flow<Payload>)? = null
private var requestChannel: (RSocket.(payloads: Flow<Payload>) -> Flow<Payload>)? = null
private var requestChannel: (RSocket.(initPayload: Payload, payloads: Flow<Payload>) -> Flow<Payload>)? = null

public fun metadataPush(block: (suspend RSocket.(metadata: ByteReadPacket) -> Unit)) {
check(metadataPush == null) { "Metadata Push handler already configured" }
Expand All @@ -48,7 +48,7 @@ class RSocketRequestHandlerBuilder internal constructor() {
requestStream = block
}

public fun requestChannel(block: (RSocket.(payloads: Flow<Payload>) -> Flow<Payload>)) {
public fun requestChannel(block: (RSocket.(initPayload: Payload, payloads: Flow<Payload>) -> Flow<Payload>)) {
check(requestChannel == null) { "Request Channel handler already configured" }
requestChannel = block
}
Expand All @@ -58,7 +58,7 @@ class RSocketRequestHandlerBuilder internal constructor() {
}

@Suppress("FunctionName")
fun RSocketRequestHandler(parentJob: Job? = null, configure: RSocketRequestHandlerBuilder.() -> Unit): RSocket {
public fun RSocketRequestHandler(parentJob: Job? = null, configure: RSocketRequestHandlerBuilder.() -> Unit): RSocket {
val builder = RSocketRequestHandlerBuilder()
builder.configure()
return builder.build(Job(parentJob))
Expand All @@ -70,7 +70,7 @@ private class RSocketRequestHandler(
private val fireAndForget: (suspend RSocket.(payload: Payload) -> Unit)? = null,
private val requestResponse: (suspend RSocket.(payload: Payload) -> Payload)? = null,
private val requestStream: (RSocket.(payload: Payload) -> Flow<Payload>)? = null,
private val requestChannel: (RSocket.(payloads: Flow<Payload>) -> Flow<Payload>)? = null,
private val requestChannel: (RSocket.(initPayload: Payload, payloads: Flow<Payload>) -> Flow<Payload>)? = null,
) : RSocket {
override suspend fun metadataPush(metadata: ByteReadPacket): Unit =
metadataPush?.invoke(this, metadata) ?: super.metadataPush(metadata)
Expand All @@ -84,7 +84,7 @@ private class RSocketRequestHandler(
override fun requestStream(payload: Payload): Flow<Payload> =
requestStream?.invoke(this, payload) ?: super.requestStream(payload)

override fun requestChannel(payloads: Flow<Payload>): Flow<Payload> =
requestChannel?.invoke(this, payloads) ?: super.requestChannel(payloads)
override fun requestChannel(initPayload: Payload, payloads: Flow<Payload>): Flow<Payload> =
requestChannel?.invoke(this, initPayload, payloads) ?: super.requestChannel(initPayload, payloads)

}
Original file line number Diff line number Diff line change
Expand Up @@ -112,8 +112,8 @@ private class ReconnectableRSocket(
emitAll(currentRSocket(payload).requestStream(payload))
}

override fun requestChannel(payloads: Flow<Payload>): Flow<Payload> = flow {
emitAll(currentRSocket().requestChannel(payloads))
override fun requestChannel(initPayload: Payload, payloads: Flow<Payload>): Flow<Payload> = flow {
emitAll(currentRSocket(initPayload).requestChannel(initPayload, payloads))
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,8 @@ internal class RSocketRequester(

override fun requestStream(payload: Payload): Flow<Payload> = RequestStreamRequesterFlow(payload, this, state)

override fun requestChannel(payloads: Flow<Payload>): Flow<Payload> = RequestChannelRequesterFlow(payloads, this, state)
override fun requestChannel(initPayload: Payload, payloads: Flow<Payload>): Flow<Payload> =
RequestChannelRequesterFlow(initPayload, payloads, this, state)

fun createStream(): Int {
checkAvailable()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -60,31 +60,23 @@ internal class RSocketResponder(
val response = requestOrCancel(streamId) {
requestHandler.requestStream(initFrame.payload)
} ?: return@launchCancelable
response.collectLimiting(
streamId,
RequestStreamResponderFlowCollector(state, streamId, initFrame.initialRequest)
)
send(CompletePayloadFrame(streamId))
response.collectLimiting(streamId, initFrame.initialRequest)
}.invokeOnCompletion {
initFrame.release()
}
}

fun handleRequestChannel(initFrame: RequestFrame): Unit = with(state) {
val streamId = initFrame.streamId
val receiver = createReceiverFor(streamId, initFrame)
val receiver = createReceiverFor(streamId)

val request = RequestChannelResponderFlow(streamId, receiver, state)

launchCancelable(streamId) {
val response = requestOrCancel(streamId) {
requestHandler.requestChannel(request)
requestHandler.requestChannel(initFrame.payload, request)
} ?: return@launchCancelable
response.collectLimiting(
streamId,
RequestStreamResponderFlowCollector(state, streamId, initFrame.initialRequest)
)
send(CompletePayloadFrame(streamId))
response.collectLimiting(streamId, initFrame.initialRequest)
}.invokeOnCompletion {
initFrame.release()
receiver.closeReceivedElements()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,9 +52,8 @@ internal class RSocketState(
prioritizer.sendPrioritized(frame)
}

fun createReceiverFor(streamId: Int, initFrame: RequestFrame? = null): ReceiveChannel<RequestFrame> {
fun createReceiverFor(streamId: Int): ReceiveChannel<RequestFrame> {
val receiver = SafeChannel<RequestFrame>(Channel.UNLIMITED)
initFrame?.let(receiver::offer) //used only in RequestChannel on responder side
receivers[streamId] = receiver
return receiver
}
Expand Down Expand Up @@ -94,11 +93,15 @@ internal class RSocketState(

suspend inline fun Flow<Payload>.collectLimiting(
streamId: Int,
limitingCollector: LimitingFlowCollector,
initialRequest: Int,
crossinline onStart: () -> Unit = {},
): Unit = coroutineScope {
val limitingCollector = LimitingFlowCollector(this@RSocketState, streamId, initialRequest)
limits[streamId] = limitingCollector
try {
onStart()
collect(limitingCollector)
send(CompletePayloadFrame(streamId))
} catch (e: Throwable) {
limits.remove(streamId)
//if isn't active, then, that stream was cancelled, and so no need for error frame
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,27 +16,30 @@

package io.rsocket.kotlin.internal.flow

import io.rsocket.kotlin.frame.*
import io.rsocket.kotlin.internal.*
import io.rsocket.kotlin.payload.*
import kotlinx.atomicfu.*
import kotlinx.coroutines.*
import kotlinx.coroutines.flow.*

internal abstract class LimitingFlowCollector(initial: Int) : FlowCollector<Payload> {
internal class LimitingFlowCollector(
private val state: RSocketState,
private val streamId: Int,
initial: Int,
) : FlowCollector<Payload> {
private val requests = atomic(initial)
private val awaiter = atomic<CancellableContinuation<Unit>?>(null)

abstract suspend fun emitValue(value: Payload)

fun updateRequests(n: Int) {
if (n <= 0) return
requests.getAndAdd(n)
awaiter.getAndSet(null)?.resumeSafely()
}

final override suspend fun emit(value: Payload): Unit = value.closeOnError {
override suspend fun emit(value: Payload): Unit = value.closeOnError {
useRequest()
emitValue(value)
state.send(NextPayloadFrame(streamId, value))
}

private suspend fun useRequest() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,15 +19,14 @@ package io.rsocket.kotlin.internal.flow
import io.rsocket.kotlin.*
import io.rsocket.kotlin.frame.*
import io.rsocket.kotlin.internal.*
import io.rsocket.kotlin.internal.cancelConsumed
import io.rsocket.kotlin.payload.*
import kotlinx.atomicfu.*
import kotlinx.coroutines.*
import kotlinx.coroutines.channels.*
import kotlinx.coroutines.flow.*

@OptIn(ExperimentalStreamsApi::class)
@OptIn(ExperimentalStreamsApi::class, ExperimentalCoroutinesApi::class)
internal class RequestChannelRequesterFlow(
private val initPayload: Payload,
private val payloads: Flow<Payload>,
private val requester: RSocketRequester,
private val state: RSocketState,
Expand All @@ -40,31 +39,25 @@ internal class RequestChannelRequesterFlow(

val strategy = currentCoroutineContext().requestStrategy()
val initialRequest = strategy.firstRequest()
val streamId = requester.createStream()
val receiverDeferred = CompletableDeferred<ReceiveChannel<RequestFrame>?>()
val request = launchCancelable(streamId) {
payloads.collectLimiting(
streamId,
RequestChannelRequesterFlowCollector(state, streamId, receiverDeferred, initialRequest)
)
if (receiverDeferred.isCompleted && !receiverDeferred.isCancelled) send(CompletePayloadFrame(streamId))
}
request.invokeOnCompletion {
if (receiverDeferred.isCompleted) {
@OptIn(ExperimentalCoroutinesApi::class)
if (it != null && it !is CancellationException) receiverDeferred.getCompleted()?.cancelConsumed(it)
} else {
if (it == null) receiverDeferred.complete(null)
else receiverDeferred.completeExceptionally(it.cause ?: it)
initPayload.closeOnError {
val streamId = requester.createStream()
val receiver = createReceiverFor(streamId)
val request = launchCancelable(streamId) {
payloads.collectLimiting(streamId, 0) {
send(RequestChannelFrame(streamId, initialRequest, initPayload))
}
}

request.invokeOnCompletion {
if (it != null && it !is CancellationException) receiver.cancelConsumed(it)
}
try {
collectStream(streamId, receiver, strategy, collector)
} catch (e: Throwable) {
if (e is CancellationException) request.cancel(e)
else request.cancel("Receiver failed", e)
throw e
}
}
try {
val receiver = receiverDeferred.await() ?: return
collectStream(streamId, receiver, strategy, collector)
} catch (e: Throwable) {
if (e is CancellationException) request.cancel(e)
else request.cancel("Receiver failed", e)
throw e
}
}
}

This file was deleted.

This file was deleted.

Loading