diff --git a/modules/core/src/main/scala/net/protocol/Execute.scala b/modules/core/src/main/scala/net/protocol/Execute.scala index 639a97f0..18ce3d58 100644 --- a/modules/core/src/main/scala/net/protocol/Execute.scala +++ b/modules/core/src/main/scala/net/protocol/Execute.scala @@ -30,7 +30,7 @@ object Execute { _ <- send(ExecuteMessage(portal.id.value, 0)) _ <- send(Flush) c <- flatExpect { - case CommandComplete(c) => c.pure[F] + case CommandComplete(c) => sync *> expect { case ReadyForQuery(_) => c } // https://github.com/tpolecat/skunk/issues/210 case ErrorResponse(info) => syncAndFail[A](portal, info) } } yield c diff --git a/modules/tests/src/test/scala/CommandTest.scala b/modules/tests/src/test/scala/CommandTest.scala index 183b0d69..abf5c128 100644 --- a/modules/tests/src/test/scala/CommandTest.scala +++ b/modules/tests/src/test/scala/CommandTest.scala @@ -18,6 +18,7 @@ case object CommandTest extends SkunkTest { (int4 ~ varchar ~ bpchar(3) ~ varchar ~ int4).gimap[City] val Garin = City(5000, "Garin", "ARG", "Escobar", 11405) + val Garin2 = City(5001, "Garin2", "ARG", "Escobar", 11405) val insertCity: Command[City] = sql""" @@ -72,7 +73,7 @@ case object CommandTest extends SkunkTest { val createSchema: Command[Void] = sql""" - CREATE SCHEMA public_0 + CREATE SCHEMA IF NOT EXISTS public_0 """.command val dropSchema: Command[Void] = @@ -120,11 +121,11 @@ case object CommandTest extends SkunkTest { sessionTest("insert and delete record with contramapped command") { s => for { - c <- s.prepare(insertCity2).use(_.execute(Garin)) + c <- s.prepare(insertCity2).use(_.execute(Garin2)) _ <- assert("completion", c == Completion.Insert(1)) - c <- s.prepare(selectCity).use(_.unique(Garin.id)) - _ <- assert("read", c == Garin) - _ <- s.prepare(deleteCity).use(_.execute(Garin.id)) + c <- s.prepare(selectCity).use(_.unique(Garin2.id)) + _ <- assert("read", c == Garin2) + _ <- s.prepare(deleteCity).use(_.execute(Garin2.id)) _ <- s.assertHealthy } yield "ok" } diff --git a/modules/tests/src/test/scala/issue/210.scala b/modules/tests/src/test/scala/issue/210.scala new file mode 100644 index 00000000..30abfee0 --- /dev/null +++ b/modules/tests/src/test/scala/issue/210.scala @@ -0,0 +1,81 @@ +// Copyright (c) 2018-2020 by Rob Norris +// This software is licensed under the MIT License (MIT). +// For more information see LICENSE or https://opensource.org/licenses/MIT + +package tests.issue + +import cats.implicits._ +import skunk._ +import skunk.codec.all._ +import skunk.implicits._ +import tests.SkunkTest +import cats.effect._ +import cats.effect.concurrent.Deferred + +// https://github.com/tpolecat/skunk/issues/210 +case object Test210 extends SkunkTest { + + // a resource that creates and drops a table + def withPetsTable(s: Session[IO]): Resource[IO, Unit] = { + val alloc = s.execute(sql"CREATE TABLE IF NOT EXISTS Test210_pets (name varchar, age int2)".command).void + val free = s.execute(sql"DROP TABLE Test210_pets".command).void + Resource.make(alloc)(_ => free) + } + + // a data type + case class Pet(name: String, age: Short) + + // command to insert a pet + val insertOne: Command[Pet] = + sql"INSERT INTO Test210_pets VALUES ($varchar, $int2)" + .command + .gcontramap[Pet] + + // command to insert a specific list of Test210_pets + def insertMany(ps: List[Pet]): Command[ps.type] = { + val enc = (varchar ~ int2).gcontramap[Pet].values.list(ps) + sql"INSERT INTO Test210_pets VALUES $enc".command + } + + // query to select all Test210_pets + def selectAll: Query[Void, Pet] = + sql"SELECT name, age FROM Test210_pets" + .query(varchar ~ int2) + .gmap[Pet] + + // some sample data + val bob = Pet("Bob", 12) + val beatles = List(Pet("John", 2), Pet("George", 3), Pet("Paul", 6), Pet("Ringo", 3)) + + def doInserts(ready: Deferred[IO, Unit], done: Deferred[IO, Unit]): IO[Unit] = + session.flatTap(withPetsTable).use { s => + for { + _ <- s.prepare(insertOne).use(pc => pc.execute(Pet("Bob", 12))) + _ <- s.prepare(insertMany(beatles)).use(pc => pc.execute(beatles)) + _ <- ready.complete(()) + _ <- done.get // wait for main fiber to finish + } yield () + } + + val check: IO[Unit] = + session.use { s => + for { + ns <- s.execute(sql"select name from Test210_pets".query(varchar)) + _ <- assertEqual("names", ns, (bob :: beatles).map(_.name)) + } yield () + } + + test("issue/210") { + for { + ready <- Deferred[IO, Unit] + done <- Deferred[IO, Unit] + fib <- doInserts(ready, done).start // fork + _ <- ready.get // wait for forked fiber to say it's ready + _ <- check.guarantee { + // ensure the fork is cleaned up so our table gets deleted + done.complete(()) *> fib.join + } + } yield "ok" + } + +} \ No newline at end of file