Skip to content

Commit

Permalink
turn FrontEndMessage into a base trait
Browse files Browse the repository at this point in the history
  • Loading branch information
tpolecat committed Jul 9, 2020
1 parent 12e5237 commit 8dbb572
Show file tree
Hide file tree
Showing 15 changed files with 94 additions and 130 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,7 @@ object BufferedMessageSocket {
new AbstractMessageSocket[F] with BufferedMessageSocket[F] {

override def receive: F[BackendMessage] = queue.dequeue1
override def send[A: FrontendMessage](a: A): F[Unit] = ms.send(a)
override def send(message: FrontendMessage): F[Unit] = ms.send(message)
override def transactionStatus: SignallingRef[F, TransactionStatus] = xaSig
override def parameters: SignallingRef[F, Map[String, String]] = paSig
override def backendKeyData: Deferred[F, BackendKeyData] = bkSig
Expand Down
10 changes: 5 additions & 5 deletions modules/core/src/main/scala/net/MessageSocket.scala
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ trait MessageSocket[F[_]] {
def receive: F[BackendMessage]

/** Send the specified message. */
def send[A: FrontendMessage](a: A): F[Unit]
def send(message: FrontendMessage): F[Unit]

/** Destructively read the last `n` messages from the circular buffer. */
def history(max: Int): F[List[Either[Any, Any]]]
Expand Down Expand Up @@ -62,11 +62,11 @@ object MessageSocket {
_ <- Sync[F].delay(println(s"${Console.GREEN}$msg${Console.RESET}")).whenA(debug)
} yield msg

override def send[A](a: A)(implicit ev: FrontendMessage[A]): F[Unit] =
override def send(message: FrontendMessage): F[Unit] =
for {
_ <- Sync[F].delay(println(s"${Console.YELLOW}$a${Console.RESET}")).whenA(debug)
_ <- bvs.write(ev.fullEncoder.encode(a).require)
_ <- cb.enqueue1(Left(a))
_ <- Sync[F].delay(println(s"${Console.YELLOW}$message${Console.RESET}")).whenA(debug)
_ <- bvs.write(message.encode)
_ <- cb.enqueue1(Left(message))
} yield ()

override def history(max: Int): F[List[Either[Any, Any]]] =
Expand Down
6 changes: 4 additions & 2 deletions modules/core/src/main/scala/net/message/Bind.scala
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,13 @@ import scodec.codecs._


case class Bind(portal: String, statement: String, args: List[Option[String]])
extends TaggedFrontendMessage('B') {
def encodeBody = Bind.encoder.encode(this)
}

object Bind {

implicit val BindFrontendMessage: FrontendMessage[Bind] =
FrontendMessage.tagged('B') {
val encoder: Encoder[Bind] = {

// String - The name of the destination portal (an empty string selects the unnamed portal).
// String - The name of the source prepared statement (an empty string selects the unnamed
Expand Down
12 changes: 6 additions & 6 deletions modules/core/src/main/scala/net/message/Close.scala
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,11 @@
package skunk.net.message

import scodec.codecs._
import scodec.Encoder

sealed abstract case class Close(variant: Byte, name: String) {
sealed abstract case class Close(variant: Byte, name: String) extends TaggedFrontendMessage('C') {
override def toString: String = s"Close(${variant.toChar},$name)"
def encodeBody = Close.encoder.encode(this)
}

object Close {
Expand All @@ -18,11 +20,9 @@ object Close {
def portal(name: String): Close =
new Close('P', name) {}

implicit val DescribeFrontendMessage: FrontendMessage[Close] =
FrontendMessage.tagged('C') {
(byte ~ utf8z).contramap[Close] { d =>
d.variant ~ d.name
}
val encoder: Encoder[Close] =
(byte ~ utf8z).contramap[Close] { d =>
d.variant ~ d.name
}

}
10 changes: 5 additions & 5 deletions modules/core/src/main/scala/net/message/Describe.scala
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,11 @@
package skunk.net.message

import scodec.codecs._
import scodec.Encoder

sealed abstract case class Describe(variant: Byte, name: String) {
sealed abstract case class Describe(variant: Byte, name: String) extends TaggedFrontendMessage('D') {
override def toString = s"Describe(${variant.toChar}, $name)"
def encodeBody = Describe.encoder.encode(this)
}

object Describe {
Expand All @@ -18,9 +20,7 @@ object Describe {
def portal(name: String): Describe =
new Describe('P', name) {}

implicit val DescribeFrontendMessage: FrontendMessage[Describe] =
FrontendMessage.tagged('D') {
(byte ~ utf8z).contramap[Describe] { d => d.variant ~ d.name }
}
val encoder: Encoder[Describe] =
(byte ~ utf8z).contramap[Describe] { d => d.variant ~ d.name }

}
13 changes: 7 additions & 6 deletions modules/core/src/main/scala/net/message/Execute.scala
Original file line number Diff line number Diff line change
Expand Up @@ -5,16 +5,17 @@
package skunk.net.message

import scodec.codecs._
import scodec.Encoder

case class Execute(portal: String, maxRows: Int)
case class Execute(portal: String, maxRows: Int) extends TaggedFrontendMessage('E') {
def encodeBody = Execute.encoder.encode(this)
}

object Execute {

implicit val ExecuteFrontendMessage: FrontendMessage[Execute] =
FrontendMessage.tagged('E') {
(utf8z ~ int32).contramap[Execute] { p =>
p.portal ~ p.maxRows
}
val encoder: Encoder[Execute] =
(utf8z ~ int32).contramap[Execute] { p =>
p.portal ~ p.maxRows
}

}
15 changes: 1 addition & 14 deletions modules/core/src/main/scala/net/message/Flush.scala
Original file line number Diff line number Diff line change
Expand Up @@ -4,17 +4,4 @@

package skunk.net.message

import scodec.Attempt
import scodec.bits._
import scodec._

case object Flush {

implicit val FlushFrontendMessage: FrontendMessage[Flush.type] =
FrontendMessage.tagged('H') {
Encoder { _ =>
Attempt.Successful(BitVector.empty)
}
}

}
case object Flush extends ConstFrontendMessage('H')
52 changes: 23 additions & 29 deletions modules/core/src/main/scala/net/message/FrontendMessage.scala
Original file line number Diff line number Diff line change
Expand Up @@ -8,38 +8,32 @@ import cats.implicits._
import scodec._
import scodec.codecs._
import scodec.interop.cats._
import scodec.bits.BitVector

/** A typeclass for messages we send to the server. */
trait FrontendMessage[A] {

/** Payload encoder (only). */
def encoder: Encoder[A]

/** Full encoder that adds a tag (if any) and length prefix. */
def fullEncoder: Encoder[A]

sealed trait FrontendMessage {
protected def encodeBody: Attempt[BitVector]
def encode: BitVector
}

object FrontendMessage {

private def lengthPrefixed[A](e: Encoder[A]): Encoder[A] =
Encoder { (a: A) =>
for {
p <- e.encode(a)
l <- int32.encode((p.size / 8).toInt + 4)
} yield l ++ p
}

def tagged[A](tag: Byte)(enc: Encoder[A]): FrontendMessage[A] =
new FrontendMessage[A] {
override val encoder: Encoder[A] = enc
override val fullEncoder: Encoder[A] = Encoder(a => byte.encode(tag) |+| lengthPrefixed(enc).encode(a))
}
abstract class UntaggedFrontendMessage extends FrontendMessage {
final def encode: BitVector = {
for {
b <- encodeBody
l <- int32.encode(((b.size) / 8).toInt + 4)
} yield l |+| b
} .require
}

def untagged[A](enc: Encoder[A]): FrontendMessage[A] =
new FrontendMessage[A] {
override val encoder: Encoder[A] = enc
override val fullEncoder: Encoder[A] = lengthPrefixed(enc)
}
abstract class TaggedFrontendMessage(tag: Byte) extends FrontendMessage {
final def encode: BitVector = {
for {
t <- byte.encode(tag)
b <- encodeBody
l <- int32.encode((b.size / 8).toInt + 4)
} yield t |+| l |+| b
} .require
}

abstract class ConstFrontendMessage(tag: Byte) extends TaggedFrontendMessage(tag) {
final def encodeBody = Attempt.successful(BitVector.empty)
}
13 changes: 7 additions & 6 deletions modules/core/src/main/scala/net/message/Parse.scala
Original file line number Diff line number Diff line change
Expand Up @@ -5,16 +5,17 @@
package skunk.net.message

import scodec.codecs._
import scodec.Encoder

case class Parse(name: String, sql: String, types: List[Int])
case class Parse(name: String, sql: String, types: List[Int]) extends TaggedFrontendMessage('P') {
def encodeBody = Parse.encoder.encode(this)
}

object Parse {

implicit val ParseFrontendMessage: FrontendMessage[Parse] =
FrontendMessage.tagged('P') {
(utf8z ~ utf8z ~ int16 ~ list(int32)).contramap[Parse] { p =>
p.name ~ p.sql ~ p.types.length ~ p.types
}
val encoder: Encoder[Parse] =
(utf8z ~ utf8z ~ int16 ~ list(int32)).contramap[Parse] { p =>
p.name ~ p.sql ~ p.types.length ~ p.types
}

}
13 changes: 7 additions & 6 deletions modules/core/src/main/scala/net/message/PasswordMessage.scala
Original file line number Diff line number Diff line change
Expand Up @@ -4,17 +4,18 @@

package skunk.net.message
import java.security.MessageDigest
import scodec.Encoder

// import scodec.codecs._

case class PasswordMessage(password: String)
sealed abstract case class PasswordMessage(password: String) extends TaggedFrontendMessage('p') {
def encodeBody = PasswordMessage.encoder.encode(this)
}

object PasswordMessage {

implicit val PasswordMessageFrontendMessage: FrontendMessage[PasswordMessage] =
FrontendMessage.tagged('p') {
utf8z.contramap[PasswordMessage](_.password)
}
val encoder: Encoder[PasswordMessage] =
utf8z.contramap[PasswordMessage](_.password)

// See https://www.postgresql.org/docs/9.6/protocol-flow.html#AEN113418
// and https://github.com/pgjdbc/pgjdbc/blob/master/pgjdbc/src/main/java/org/postgresql/util/MD5Digest.java
Expand All @@ -38,7 +39,7 @@ object PasswordMessage {
hex = "0" + hex

// Done
PasswordMessage("md5" + hex)
new PasswordMessage("md5" + hex) {}

}

Expand Down
16 changes: 8 additions & 8 deletions modules/core/src/main/scala/net/message/Query.scala
Original file line number Diff line number Diff line change
Expand Up @@ -8,17 +8,17 @@ import scodec.Attempt
import scodec.bits._
import scodec._

case class Query(sql: String)
case class Query(sql: String) extends TaggedFrontendMessage('Q') {
def encodeBody = Query.encoder.encode(this)
}

object Query {

implicit val QueryFrontendMessage: FrontendMessage[Query] =
FrontendMessage.tagged('Q') {
Encoder { q =>
val barr = q.sql.getBytes("UTF8")
val barrʹ = java.util.Arrays.copyOf(barr, barr.length + 1) // add NUL
Attempt.Successful(BitVector(barrʹ))
}
val encoder: Encoder[Query] =
Encoder { q =>
val barr = q.sql.getBytes("UTF8")
val barrʹ = java.util.Arrays.copyOf(barr, barr.length + 1) // add NUL
Attempt.Successful(BitVector(barrʹ))
}

}
31 changes: 16 additions & 15 deletions modules/core/src/main/scala/net/message/StartupMessage.scala
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,9 @@ import scodec._
import scodec.codecs._

// TODO: SUPPORT OTHER PARAMETERS
case class StartupMessage(user: String, database: String) {
case class StartupMessage(user: String, database: String) extends UntaggedFrontendMessage {

def encodeBody = StartupMessage.encoder.encode(this)

// HACK: we will take a plist eventually
val properties: Map[String, String] =
Expand All @@ -26,24 +28,23 @@ object StartupMessage {
"client_encoding" -> "UTF8",
)

implicit val StartupMessageFrontendMessage: FrontendMessage[StartupMessage] =
FrontendMessage.untagged {
val encoder: Encoder[StartupMessage] = {

def pair(key: String): Codec[String] =
utf8z.applied(key) ~> utf8z
def pair(key: String): Codec[String] =
utf8z.applied(key) ~> utf8z

val version: Codec[Unit] =
int32.applied(196608)
val version: Codec[Unit] =
int32.applied(196608)

// After user and database we have a null-terminated list of fixed key-value pairs, which
// specify connection properties that affect serialization and are REQUIRED by Skunk.
val tail: Codec[Unit] =
ConnectionProperties.foldRight(byte.applied(0)) { case ((k, v), e) => pair(k).applied(v) <~ e}
// After user and database we have a null-terminated list of fixed key-value pairs, which
// specify connection properties that affect serialization and are REQUIRED by Skunk.
val tail: Codec[Unit] =
ConnectionProperties.foldRight(byte.applied(0)) { case ((k, v), e) => pair(k).applied(v) <~ e}

(version ~> pair("user") ~ pair("database") <~ tail)
.asEncoder
.contramap(m => m.user ~ m.database)
(version ~> pair("user") ~ pair("database") <~ tail)
.asEncoder
.contramap(m => m.user ~ m.database)

}
}

}
15 changes: 1 addition & 14 deletions modules/core/src/main/scala/net/message/Sync.scala
Original file line number Diff line number Diff line change
Expand Up @@ -4,17 +4,4 @@

package skunk.net.message

import scodec.Attempt
import scodec.bits._
import scodec._

case object Sync {

implicit val SyncFrontendMessage: FrontendMessage[Sync.type] =
FrontendMessage.tagged('S') {
Encoder { _ =>
Attempt.Successful(BitVector.empty)
}
}

}
case object Sync extends ConstFrontendMessage('S')
12 changes: 1 addition & 11 deletions modules/core/src/main/scala/net/message/Terminate.scala
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,4 @@

package skunk.net.message

import scodec._
import scodec.bits._

case object Terminate {

implicit val TerminateFrontendMessage: FrontendMessage[Terminate.type] =
FrontendMessage.tagged('X') {
Encoder(_ => Attempt.successful(BitVector.empty))
}

}
case object Terminate extends ConstFrontendMessage('X')
4 changes: 2 additions & 2 deletions modules/core/src/main/scala/net/protocol/package.scala
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,8 @@ import skunk.util.Origin
def receive[F[_]](implicit ev: MessageSocket[F]): F[BackendMessage] =
ev.receive

def send[F[_], A: FrontendMessage](a: A)(implicit ev: MessageSocket[F]): F[Unit] =
ev.send(a)
def send[F[_]](message: FrontendMessage)(implicit ev: MessageSocket[F]): F[Unit] =
ev.send(message)

def history[F[_]](max: Int)(implicit ev: MessageSocket[F]): F[List[Either[Any, Any]]] =
ev.history(max)
Expand Down

0 comments on commit 8dbb572

Please sign in to comment.