Skip to content

Commit

Permalink
Changed the FS2Channel to have input/output streams based on Message
Browse files Browse the repository at this point in the history
Also, changed the Endpoint interfaces to carry over the InputMessage
as parameters of their run functions. This should facilitate the
proxification of servers, as the implementor of the proxy endpoint will
be able to shove the input message onto the underlying's server's
message queue.
  • Loading branch information
Baccata committed Dec 28, 2022
1 parent 88b109a commit 1108514
Show file tree
Hide file tree
Showing 7 changed files with 50 additions and 39 deletions.
17 changes: 10 additions & 7 deletions core/src/jsonrpclib/Endpoint.scala
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,10 @@ object Endpoint {
def apply[In, Err, Out](
run: In => F[Either[Err, Out]]
)(implicit inCodec: Codec[In], errCodec: ErrorCodec[Err], outCodec: Codec[Out]): Endpoint[F] =
RequestResponseEndpoint(method, (_: Method, in: In) => run(in), inCodec, errCodec, outCodec)
RequestResponseEndpoint(method, (_: InputMessage, in: In) => run(in), inCodec, errCodec, outCodec)

def full[In, Err, Out](
run: (Method, In) => F[Either[Err, Out]]
run: (InputMessage, In) => F[Either[Err, Out]]
)(implicit inCodec: Codec[In], errCodec: ErrorCodec[Err], outCodec: Codec[Out]): Endpoint[F] =
RequestResponseEndpoint(method, run, inCodec, errCodec, outCodec)

Expand All @@ -33,19 +33,22 @@ object Endpoint {
)

def notification[In](run: In => F[Unit])(implicit inCodec: Codec[In]): Endpoint[F] =
NotificationEndpoint(method, (_: Method, in: In) => run(in), inCodec)
NotificationEndpoint(method, (_: InputMessage, in: In) => run(in), inCodec)

def notificationFull[In](run: (Method, In) => F[Unit])(implicit inCodec: Codec[In]): Endpoint[F] =
def notificationFull[In](run: (InputMessage, In) => F[Unit])(implicit inCodec: Codec[In]): Endpoint[F] =
NotificationEndpoint(method, run, inCodec)

}

final case class NotificationEndpoint[F[_], In](method: Method, run: (Method, In) => F[Unit], inCodec: Codec[In])
extends Endpoint[F]
final case class NotificationEndpoint[F[_], In](
method: MethodPattern,
run: (InputMessage, In) => F[Unit],
inCodec: Codec[In]
) extends Endpoint[F]

final case class RequestResponseEndpoint[F[_], In, Err, Out](
method: Method,
run: (Method, In) => F[Either[Err, Out]],
run: (InputMessage, In) => F[Either[Err, Out]],
inCodec: Codec[In],
errCodec: ErrorCodec[Err],
outCodec: Codec[Out]
Expand Down
Original file line number Diff line number Diff line change
@@ -1,34 +1,32 @@
package jsonrpclib
package internals

import com.github.plokhotnyuk.jsoniter_scala.core.JsonReader
import com.github.plokhotnyuk.jsoniter_scala.core.JsonValueCodec
import com.github.plokhotnyuk.jsoniter_scala.core.JsonWriter

sealed trait Message { def maybeCallId: Option[CallId] }
private[jsonrpclib] sealed trait InputMessage extends Message { def method: String }
private[jsonrpclib] sealed trait OutputMessage extends Message {
sealed trait InputMessage extends Message { def method: String }
sealed trait OutputMessage extends Message {
def callId: CallId; final override def maybeCallId: Option[CallId] = Some(callId)
}

private[jsonrpclib] object InputMessage {
object InputMessage {
case class RequestMessage(method: String, callId: CallId, params: Option[Payload]) extends InputMessage {
def maybeCallId: Option[CallId] = Some(callId)
}
case class NotificationMessage(method: String, params: Option[Payload]) extends InputMessage {
def maybeCallId: Option[CallId] = None
}
}

private[jsonrpclib] object OutputMessage {
object OutputMessage {
def errorFrom(callId: CallId, protocolError: ProtocolError): OutputMessage =
ErrorMessage(callId, ErrorPayload(protocolError.code, protocolError.getMessage(), None))

case class ErrorMessage(callId: CallId, payload: ErrorPayload) extends OutputMessage
case class ResponseMessage(callId: CallId, data: Payload) extends OutputMessage
}

private[jsonrpclib] object Message {
object Message {

implicit val messageJsonValueCodecs: JsonValueCodec[Message] = new JsonValueCodec[Message] {
val rawMessageCodec = implicitly[JsonValueCodec[internals.RawMessage]]
Expand Down
21 changes: 8 additions & 13 deletions core/src/jsonrpclib/internals/MessageDispatcher.scala
Original file line number Diff line number Diff line change
@@ -1,11 +1,10 @@
package jsonrpclib
package internals

import jsonrpclib.internals._
import jsonrpclib.Endpoint.NotificationEndpoint
import jsonrpclib.Endpoint.RequestResponseEndpoint
import jsonrpclib.internals.OutputMessage.ErrorMessage
import jsonrpclib.internals.OutputMessage.ResponseMessage
import jsonrpclib.OutputMessage.ErrorMessage
import jsonrpclib.OutputMessage.ResponseMessage
import scala.util.Try

private[jsonrpclib] abstract class MessageDispatcher[F[_]](implicit F: Monadic[F]) extends Channel.MonadicChannel[F] {
Expand Down Expand Up @@ -41,8 +40,8 @@ private[jsonrpclib] abstract class MessageDispatcher[F[_]](implicit F: Monadic[F
}
}

protected[jsonrpclib] def handleReceivedPayload(payload: Payload): F[Unit] = {
Codec.decode[Message](Some(payload)).map {
protected[jsonrpclib] def handleReceivedMessage(message: Message): F[Unit] = {
message match {
case im: InputMessage =>
doFlatMap(getEndpoint(im.method)) {
case Some(ep) => background(im.maybeCallId, executeInputMessage(im, ep))
Expand All @@ -61,29 +60,25 @@ private[jsonrpclib] abstract class MessageDispatcher[F[_]](implicit F: Monadic[F
case Some(pendingCall) => pendingCall(om)
case None => doPure(()) // TODO do something
}
} match {
case Left(error) =>
sendProtocolError(error)
case Right(dispatch) => dispatch
}
}

private def sendProtocolError(callId: CallId, pError: ProtocolError): F[Unit] =
protected def sendProtocolError(callId: CallId, pError: ProtocolError): F[Unit] =
sendMessage(OutputMessage.errorFrom(callId, pError))
private def sendProtocolError(pError: ProtocolError): F[Unit] =
protected def sendProtocolError(pError: ProtocolError): F[Unit] =
sendProtocolError(CallId.NullId, pError)

private def executeInputMessage(input: InputMessage, endpoint: Endpoint[F]): F[Unit] = {
(input, endpoint) match {
case (InputMessage.NotificationMessage(_, params), ep: NotificationEndpoint[F, in]) =>
ep.inCodec.decode(params) match {
case Right(value) => ep.run(input.method, value)
case Right(value) => ep.run(input, value)
case Left(value) => reportError(params, value, ep.method)
}
case (InputMessage.RequestMessage(_, callId, params), ep: RequestResponseEndpoint[F, in, err, out]) =>
ep.inCodec.decode(params) match {
case Right(value) =>
doFlatMap(ep.run(input.method, value)) {
doFlatMap(ep.run(input, value)) {
case Right(data) =>
val responseData = ep.outCodec.encode(data)
sendMessage(OutputMessage.ResponseMessage(callId, responseData))
Expand Down
4 changes: 2 additions & 2 deletions examples/client/src/examples/client/ClientMain.scala
Original file line number Diff line number Diff line change
Expand Up @@ -41,8 +41,8 @@ object ClientMain extends IOApp.Simple {
// Creating a channel that will be used to communicate to the server
fs2Channel <- FS2Channel[IO](cancelTemplate = cancelEndpoint.some)
_ <- Stream(())
.concurrently(fs2Channel.output.through(lsp.encodePayloads).through(rp.stdin))
.concurrently(rp.stdout.through(lsp.decodePayloads).through(fs2Channel.input))
.concurrently(fs2Channel.output.through(lsp.encodeMessages).through(rp.stdin))
.concurrently(rp.stdout.through(lsp.decodeMessages).through(fs2Channel.inputOrBounce))
.concurrently(rp.stderr.through(fs2.io.stderr[IO]))
// Creating a `IntWrapper => IO[IntWrapper]` stub that can call the server
increment = fs2Channel.simpleStub[IntWrapper, IntWrapper]("increment")
Expand Down
4 changes: 2 additions & 2 deletions examples/server/src/examples/server/ServerMain.scala
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,8 @@ object ServerMain extends IOApp.Simple {
.flatMap(channel =>
fs2.Stream
.eval(IO.never) // running the server forever
.concurrently(stdin[IO](512).through(lsp.decodePayloads).through(channel.input))
.concurrently(channel.output.through(lsp.encodePayloads).through(stdout[IO]))
.concurrently(stdin[IO](512).through(lsp.decodeMessages).through(channel.inputOrBounce))
.concurrently(channel.output.through(lsp.encodeMessages).through(stdout[IO]))
)
.compile
.drain
Expand Down
20 changes: 12 additions & 8 deletions fs2/src/jsonrpclib/fs2/FS2Channel.scala
Original file line number Diff line number Diff line change
Expand Up @@ -13,15 +13,15 @@ import cats.effect.std.Supervisor
import cats.syntax.all._
import cats.effect.syntax.all._
import jsonrpclib.internals.MessageDispatcher
import jsonrpclib.internals._

import scala.util.Try
import java.util.regex.Pattern

trait FS2Channel[F[_]] extends Channel[F] {

def input: Pipe[F, Payload, Unit]
def output: Stream[F, Payload]
def input: Pipe[F, Message, Unit]
def inputOrBounce: Pipe[F, Either[ProtocolError, Message], Unit]
def output: Stream[F, Message]

def withEndpoint(endpoint: Endpoint[F])(implicit F: Functor[F]): Resource[F, FS2Channel[F]] =
Resource.make(mountEndpoint(endpoint))(_ => unmountEndpoint(endpoint.method)).map(_ => this)
Expand Down Expand Up @@ -54,7 +54,7 @@ object FS2Channel {
for {
supervisor <- Stream.resource(Supervisor[F])
ref <- Ref[F].of(State[F](Map.empty, Map.empty, Map.empty, Vector.empty, 0)).toStream
queue <- cats.effect.std.Queue.bounded[F, Payload](bufferSize).toStream
queue <- cats.effect.std.Queue.bounded[F, Message](bufferSize).toStream
impl = new Impl(queue, ref, supervisor, cancelTemplate)

// Creating a bespoke endpoint to receive cancelation requests
Expand Down Expand Up @@ -116,16 +116,20 @@ object FS2Channel {
}

private class Impl[F[_]](
private val queue: cats.effect.std.Queue[F, Payload],
private val queue: cats.effect.std.Queue[F, Message],
private val state: Ref[F, FS2Channel.State[F]],
supervisor: Supervisor[F],
maybeCancelTemplate: Option[CancelTemplate]
)(implicit F: Concurrent[F])
extends MessageDispatcher[F]
with FS2Channel[F] {

def output: Stream[F, Payload] = Stream.fromQueueUnterminated(queue)
def input: Pipe[F, Payload, Unit] = _.evalMap(handleReceivedPayload)
def output: Stream[F, Message] = Stream.fromQueueUnterminated(queue)
def inputOrBounce: Pipe[F, Either[ProtocolError, Message], Unit] = _.evalMap {
case Left(error) => sendProtocolError(error)
case Right(message) => handleReceivedMessage(message)
}
def input: Pipe[F, Message, Unit] = _.evalMap(handleReceivedMessage)

def mountEndpoint(endpoint: Endpoint[F]): F[Unit] = state
.modify(s =>
Expand Down Expand Up @@ -154,7 +158,7 @@ object FS2Channel {
}
protected def reportError(params: Option[Payload], error: ProtocolError, method: String): F[Unit] = ???
protected def getEndpoint(method: String): F[Option[Endpoint[F]]] = state.get.map(_.getEndpoint(method))
protected def sendMessage(message: Message): F[Unit] = queue.offer(Codec.encode(message))
protected def sendMessage(message: Message): F[Unit] = queue.offer(message)

protected def nextCallId(): F[CallId] = state.modify(_.nextCallId)
protected def createPromise[A](callId: CallId): F[(Try[A] => F[Unit], () => F[A])] = Deferred[F, Try[A]].map {
Expand Down
11 changes: 11 additions & 0 deletions fs2/src/jsonrpclib/fs2/lsp.scala
Original file line number Diff line number Diff line change
Expand Up @@ -6,15 +6,26 @@ import fs2.Chunk
import fs2.Stream
import fs2.Pipe
import jsonrpclib.Payload
import jsonrpclib.Codec

import java.nio.charset.Charset
import java.nio.charset.StandardCharsets
import jsonrpclib.Message
import jsonrpclib.ProtocolError

object lsp {

def encodeMessages[F[_]]: Pipe[F, Message, Byte] =
(_: Stream[F, Message]).map(Codec.encode(_)).through(encodePayloads)

def encodePayloads[F[_]]: Pipe[F, Payload, Byte] =
(_: Stream[F, Payload]).map(writeChunk).flatMap(Stream.chunk(_))

def decodeMessages[F[_]: MonadThrow]: Pipe[F, Byte, Either[ProtocolError, Message]] =
(_: Stream[F, Byte]).through(decodePayloads).map { payload =>
Codec.decode[Message](Some(payload))
}

/** Split a stream of bytes into payloads by extracting each frame based on information contained in the headers.
*
* See https://microsoft.github.io/language-server-protocol/specifications/lsp/3.17/specification/#contentPart
Expand Down

0 comments on commit 1108514

Please sign in to comment.