Skip to content

Support for execution of any single or multi query statements with discarded completions/rows #1023

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Feb 25, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions modules/core/shared/src/main/scala/Session.scala
Original file line number Diff line number Diff line change
Expand Up @@ -195,6 +195,12 @@ trait Session[F[_]] {
@deprecated("Use execute(command)(args) instead of execute(command, args)", "0.6")
def execute[A](command: Command[A], args: A)(implicit ev: DummyImplicit): F[Completion] = execute(command)(args)

/**
* Execute any non-parameterized statement containing single or multi-query statements,
* discarding returned completions and rows.
*/
def executeDiscard(statement: Statement[Void]): F[Unit]

/**
* Prepares then caches a query, yielding a `PreparedQuery` which can be executed multiple
* times with different arguments.
Expand Down Expand Up @@ -622,6 +628,9 @@ object Session {

override def execute(command: Command[Void]): F[Completion] =
proto.execute(command)

override def executeDiscard(statement: Statement[Void]): F[Unit] =
proto.executeDiscard(statement)

override def channel(name: Identifier): Channel[F, String, String] =
Channel.fromNameAndProtocol(name, proto)
Expand Down Expand Up @@ -701,6 +710,8 @@ object Session {

override def execute(command: Command[Void]): G[Completion] = fk(outer.execute(command))

override def executeDiscard(statement: Statement[Void]): G[Unit] = fk(outer.executeDiscard(statement))

override def execute[A](query: Query[Void,A]): G[List[A]] = fk(outer.execute(query))

override def option[A](query: Query[Void,A]): G[Option[A]] = fk(outer.option(query))
Expand Down
8 changes: 8 additions & 0 deletions modules/core/shared/src/main/scala/net/Protocol.scala
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,12 @@ trait Protocol[F[_]] {
*/
def execute[A](query: Query[Void, A], ty: Typer): F[List[A]]

/**
* Execute any non-parameterized statement containing single or multi-query statements,
* discarding returned completions and rows.
*/
def executeDiscard(statement: Statement[Void]): F[Unit]

/**
* Initiate the session. This must be the first thing you do. This is very basic at the moment.
*/
Expand Down Expand Up @@ -246,6 +252,8 @@ object Protocol {
override def execute[B](query: Query[Void, B], ty: Typer): F[List[B]] =
protocol.Query[F](redactionStrategy).apply(query, ty)

override def executeDiscard(statement: Statement[Void]): F[Unit] = protocol.Query[F](redactionStrategy).applyDiscard(statement)

override def startup(user: String, database: String, password: Option[String], parameters: Map[String, String]): F[Unit] =
protocol.Startup[F].apply(user, database, password, parameters)

Expand Down
42 changes: 40 additions & 2 deletions modules/core/shared/src/main/scala/net/protocol/Query.scala
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ import skunk.exception.EmptyStatementException
trait Query[F[_]] {
def apply(command: Command[Void]): F[Completion]
def apply[B](query: skunk.Query[Void, B], ty: Typer): F[List[B]]
def applyDiscard(statement: Statement[Void]): F[Unit]
}

object Query {
Expand Down Expand Up @@ -212,8 +213,45 @@ object Query {

}

// Finish up any single or multi-query statement, discard returned completions and/or rows
// Fail with first encountered error
def finishUpDiscard(stmt: Statement[_], error: Option[SkunkException]): F[Unit] =
flatExpect {
case ReadyForQuery(_) => error match {
case None => ().pure[F]
case Some(e) => e.raiseError[F, Unit]
}

case RowDescription(_) | RowData(_) | CommandComplete(_) | EmptyQueryResponse | NoticeResponse(_) =>
finishUpDiscard(stmt, error)

case ErrorResponse(info) =>
error match {
case None =>
for {
hi <- history(Int.MaxValue)
err = new PostgresErrorException(stmt.sql, Some(stmt.origin), info, hi)
c <- finishUpDiscard(stmt, Some(err))
} yield c
case _ => finishUpDiscard(stmt, error)
}

}
// We don't support COPY FROM STDIN yet but we do need to be able to clean up if a user
// tries it.
case CopyInResponse(_) =>
send(CopyFail) *>
expect { case ErrorResponse(_) => } *>
finishUpDiscard(stmt, error.orElse(new CopyNotSupportedException(stmt).some))

}
case CopyOutResponse(_) =>
finishCopyOut *> finishUpDiscard(stmt, error.orElse(new CopyNotSupportedException(stmt).some))
}

override def applyDiscard(statement: Statement[Void]): F[Unit] =
exchange("query") { (span: Span[F]) =>
span.addAttribute(
Attribute("command.sql", statement.sql)
) *> send(QueryMessage(statement.sql)) *> finishUpDiscard(statement, None)
}
}
}
51 changes: 46 additions & 5 deletions modules/tests/shared/src/test/scala/MultipleStatementsTest.scala
Original file line number Diff line number Diff line change
Expand Up @@ -8,22 +8,63 @@ import cats.syntax.all._
import skunk._
import skunk.implicits._
import skunk.codec.all._
import skunk.exception.SkunkException
import skunk.exception.{SkunkException, CopyNotSupportedException, PostgresErrorException}
import cats.effect.IO
import skunk.exception.PostgresErrorException

class MultipleStatementsTest extends SkunkTest {

val statements: List[(Query[Void,Int], Command[Void])] =
val statements: List[(Query[Void,Int], Command[Void], Statement[Void])] =
List("select 1","commit","copy country from stdin","copy country to stdout") // one per protocol
.permutations
.toList
.map { ss => ss.intercalate(";") }
.map { s => (sql"#$s".query(int4), sql"#$s".command) }
.map { s => (sql"#$s".query(int4), sql"#$s".command, sql"#$s".command) }

statements.foreach { case (q, c) =>
statements.foreach { case (q, c, any) =>
sessionTest(s"query: ${q.sql}") { s => s.execute(q).assertFailsWith[SkunkException] *> s.assertHealthy }
sessionTest(s"command: ${c.sql}") { s => s.execute(c).assertFailsWith[SkunkException] *> s.assertHealthy }
sessionTest(s"any discarded: ${any.sql}") { s => s.executeDiscard(any).assertFailsWith[CopyNotSupportedException] *> s.assertHealthy }
}

// statements with no errors
List(
"""CREATE FUNCTION do_something() RETURNS integer AS $$ BEGIN RETURN 1; END; $$ LANGUAGE plpgsql;
SELECT do_something();
DROP FUNCTION do_something
""",
"""ALTER TABLE city ADD COLUMN idd SERIAL;
SELECT setval('city_idd_seq', max(id)) FROM city;
ALTER TABLE city DROP COLUMN idd""",
"/* empty */")
.permutations
.toList
.map(s => sql"#${s.intercalate(";")}".command)
.foreach { stmt =>
sessionTest(s"discarded no errors: ${stmt.sql}") { s =>
s.executeDiscard(stmt) *> s.assertHealthy
}
}

// statements with different errors
{
val copy = "copy country from stdin"
val conflict = "create table country()"

Vector("select 1","commit",conflict,copy)
.permutations
.toVector
.foreach { statements =>
val stmt = sql"#${statements.intercalate(";")}".command

if (statements.indexOf(conflict) < statements.indexOf(copy))
sessionTest(s"discarded with postgres error: ${stmt.sql}")(s =>
s.executeDiscard(stmt).assertFailsWith[PostgresErrorException] *> s.assertHealthy
)
else
sessionTest(s"discarded with unsupported error: ${stmt.sql}")(s =>
s.executeDiscard(stmt).assertFailsWith[CopyNotSupportedException] *> s.assertHealthy
)
}
}

sessionTest("extended query (postgres raises an error here)") { s =>
Expand Down