Skip to content

Commit

Permalink
Merge pull request #2 from neandertech/fix-race-condition
Browse files Browse the repository at this point in the history
Adds a function to manually open the channel
  • Loading branch information
Baccata authored Aug 5, 2022
2 parents 8bc0e75 + 3d83715 commit 9f9e974
Show file tree
Hide file tree
Showing 3 changed files with 34 additions and 12 deletions.
26 changes: 18 additions & 8 deletions fs2/src/jsonrpclib/fs2/FS2Channel.scala
Original file line number Diff line number Diff line change
Expand Up @@ -14,37 +14,43 @@ import jsonrpclib.internals.MessageDispatcher
import jsonrpclib.internals._

import scala.util.Try
import _root_.fs2.concurrent.SignallingRef

trait FS2Channel[F[_]] extends Channel[F] {
def withEndpoint(endpoint: Endpoint[F])(implicit F: Functor[F]): Resource[F, Unit] =
Resource.make(mountEndpoint(endpoint))(_ => unmountEndpoint(endpoint.method))

def withEndpoints(endpoint: Endpoint[F], rest: Endpoint[F]*)(implicit F: Monad[F]): Resource[F, Unit] =
(endpoint :: rest.toList).traverse_(withEndpoint)

def open: Resource[F, Unit]
def openStream: Stream[F, Unit]
}

object FS2Channel {

def lspCompliant[F[_]: Concurrent](
byteStream: Stream[F, Byte],
byteSink: Pipe[F, Byte, Nothing],
startingEndpoints: List[Endpoint[F]] = List.empty,
bufferSize: Int = 512
): Stream[F, FS2Channel[F]] = internals.LSP.writeSink(byteSink, bufferSize).flatMap { sink =>
apply[F](internals.LSP.readStream(byteStream), sink, startingEndpoints)
apply[F](internals.LSP.readStream(byteStream), sink)
}

def apply[F[_]: Concurrent](
payloadStream: Stream[F, Payload],
payloadSink: Payload => F[Unit],
startingEndpoints: List[Endpoint[F]] = List.empty[Endpoint[F]]
payloadSink: Payload => F[Unit]
): Stream[F, FS2Channel[F]] = {
val endpointsMap = startingEndpoints.map(ep => ep.method -> ep).toMap
for {
supervisor <- Stream.resource(Supervisor[F])
ref <- Ref[F].of(State[F](Map.empty, endpointsMap, 0)).toStream
impl = new Impl(payloadSink, ref, supervisor)
_ <- Stream(()).concurrently(payloadStream.evalMap(impl.handleReceivedPayload))
ref <- Ref[F].of(State[F](Map.empty, Map.empty, 0)).toStream
isOpen <- SignallingRef[F].of(false).toStream
awaitingSink = isOpen.waitUntil(identity) >> payloadSink(_: Payload)
impl = new Impl(awaitingSink, ref, isOpen, supervisor)
_ <- Stream(()).concurrently {
// Gatekeeping the pull until the channel is actually marked as open
payloadStream.pauseWhen(isOpen.map(b => !b)).evalMap(impl.handleReceivedPayload)
}
} yield impl
}

Expand Down Expand Up @@ -72,6 +78,7 @@ object FS2Channel {
private class Impl[F[_]](
private val sink: Payload => F[Unit],
private val state: Ref[F, FS2Channel.State[F]],
private val isOpen: SignallingRef[F, Boolean],
supervisor: Supervisor[F]
)(implicit F: Concurrent[F])
extends MessageDispatcher[F]
Expand All @@ -88,6 +95,9 @@ object FS2Channel {

def unmountEndpoint(method: String): F[Unit] = state.update(_.removeEndpoint(method))

def open: Resource[F, Unit] = Resource.make[F, Unit](isOpen.set(true))(_ => isOpen.set(false))
def openStream: Stream[F, Unit] = Stream.resource(open)

protected def background[A](fa: F[A]): F[Unit] = supervisor.supervise(fa).void
protected def reportError(params: Option[Payload], error: ProtocolError, method: String): F[Unit] = ???
protected def getEndpoint(method: String): F[Option[Endpoint[F]]] = state.get.map(_.endpoints.get(method))
Expand Down
6 changes: 6 additions & 0 deletions fs2/src/jsonrpclib/fs2/package.scala
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,19 @@ package jsonrpclib
import _root_.fs2.Stream
import cats.MonadThrow
import cats.Monad
import cats.effect.kernel.Resource
import cats.effect.kernel.MonadCancel

package object fs2 {

private[jsonrpclib] implicit class EffectOps[F[_], A](private val fa: F[A]) extends AnyVal {
def toStream: Stream[F, A] = Stream.eval(fa)
}

private[jsonrpclib] implicit class ResourceOps[F[_], A](private val fa: Resource[F, A]) extends AnyVal {
def asStream(implicit F: MonadCancel[F, Throwable]): Stream[F, A] = Stream.resource(fa)
}

implicit def catsMonadic[F[_]: MonadThrow]: Monadic[F] = new Monadic[F] {
def doFlatMap[A, B](fa: F[A])(f: A => F[B]): F[B] = Monad[F].flatMap(fa)(f)

Expand Down
14 changes: 10 additions & 4 deletions fs2/test/src/jsonrpclib/fs2/FS2ChannelSpec.scala
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ object FS2ChannelSpec extends SimpleIOSuite {
}

def testRes(name: TestName)(run: Stream[IO, Expectations]): Unit =
test(name)(run.compile.lastOrError)
test(name)(run.compile.lastOrError.timeout(10.second))

testRes("Round trip") {
val endpoint: Endpoint[IO] = Endpoint[IO]("inc").simple((int: IntWrapper) => IO(IntWrapper(int.int + 1)))
Expand All @@ -31,8 +31,10 @@ object FS2ChannelSpec extends SimpleIOSuite {
stdin <- Queue.bounded[IO, Payload](10).toStream
serverSideChannel <- FS2Channel[IO](Stream.fromQueueUnterminated(stdin), stdout.offer)
clientSideChannel <- FS2Channel[IO](Stream.fromQueueUnterminated(stdout), stdin.offer)
_ <- Stream.resource(serverSideChannel.withEndpoint(endpoint))
_ <- serverSideChannel.withEndpoint(endpoint).asStream
remoteFunction = clientSideChannel.simpleStub[IntWrapper, IntWrapper]("inc")
_ <- serverSideChannel.open.asStream
_ <- clientSideChannel.open.asStream
result <- remoteFunction(IntWrapper(1)).toStream
} yield {
expect.same(result, IntWrapper(2))
Expand All @@ -44,9 +46,11 @@ object FS2ChannelSpec extends SimpleIOSuite {
for {
stdout <- Queue.bounded[IO, Payload](10).toStream
stdin <- Queue.bounded[IO, Payload](10).toStream
_ <- FS2Channel[IO](Stream.fromQueueUnterminated(stdin), stdout.offer)
serverSideChannel <- FS2Channel[IO](Stream.fromQueueUnterminated(stdin), stdout.offer)
clientSideChannel <- FS2Channel[IO](Stream.fromQueueUnterminated(stdout), stdin.offer)
remoteFunction = clientSideChannel.simpleStub[IntWrapper, IntWrapper]("inc")
_ <- serverSideChannel.open.asStream
_ <- clientSideChannel.open.asStream
result <- remoteFunction(IntWrapper(1)).attempt.toStream
} yield {
expect.same(result, Left(ErrorPayload(-32601, "Method inc not found", None)))
Expand All @@ -65,8 +69,10 @@ object FS2ChannelSpec extends SimpleIOSuite {
stdin <- Queue.bounded[IO, Payload](10).toStream
serverSideChannel <- FS2Channel[IO](Stream.fromQueueUnterminated(stdin), payload => stdout.offer(payload))
clientSideChannel <- FS2Channel[IO](Stream.fromQueueUnterminated(stdout), payload => stdin.offer(payload))
_ <- Stream.resource(serverSideChannel.withEndpoint(endpoint))
_ <- serverSideChannel.withEndpoint(endpoint).asStream
remoteFunction = clientSideChannel.simpleStub[IntWrapper, IntWrapper]("inc")
_ <- serverSideChannel.open.asStream
_ <- clientSideChannel.open.asStream
timedResults <- (1 to 10).toList.map(IntWrapper(_)).parTraverse(remoteFunction).timed.toStream
} yield {
val (time, results) = timedResults
Expand Down

0 comments on commit 9f9e974

Please sign in to comment.