Skip to content

Fix transaction not being rolled back on TransactionBuilder::start() Future dropped before completion #1127

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 3 additions & 39 deletions tokio-postgres/src/client.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use crate::codec::{BackendMessages, FrontendMessage};
use crate::codec::BackendMessages;
use crate::config::SslMode;
use crate::connection::{Request, RequestMessages};
use crate::copy_out::CopyOutStream;
Expand All @@ -21,7 +21,7 @@ use fallible_iterator::FallibleIterator;
use futures_channel::mpsc;
use futures_util::{future, pin_mut, ready, StreamExt, TryStreamExt};
use parking_lot::Mutex;
use postgres_protocol::message::{backend::Message, frontend};
use postgres_protocol::message::backend::Message;
use postgres_types::BorrowToSql;
use std::collections::HashMap;
use std::fmt;
Expand Down Expand Up @@ -469,43 +469,7 @@ impl Client {
///
/// The transaction will roll back by default - use the `commit` method to commit it.
pub async fn transaction(&mut self) -> Result<Transaction<'_>, Error> {
struct RollbackIfNotDone<'me> {
client: &'me Client,
done: bool,
}

impl<'a> Drop for RollbackIfNotDone<'a> {
fn drop(&mut self) {
if self.done {
return;
}

let buf = self.client.inner().with_buf(|buf| {
frontend::query("ROLLBACK", buf).unwrap();
buf.split().freeze()
});
let _ = self
.client
.inner()
.send(RequestMessages::Single(FrontendMessage::Raw(buf)));
}
}

// This is done, as `Future` created by this method can be dropped after
// `RequestMessages` is synchronously send to the `Connection` by
// `batch_execute()`, but before `Responses` is asynchronously polled to
// completion. In that case `Transaction` won't be created and thus
// won't be rolled back.
{
let mut cleaner = RollbackIfNotDone {
client: self,
done: false,
};
self.batch_execute("BEGIN").await?;
cleaner.done = true;
}

Ok(Transaction::new(self))
Transaction::begin(self).await
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it'd be simpler to just have this call self.transaction_builder().start().await and have the cancellation logic in one place.

}

/// Returns a builder for a transaction with custom settings.
Expand Down
68 changes: 66 additions & 2 deletions tokio-postgres/src/transaction.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ use crate::types::{BorrowToSql, ToSql, Type};
#[cfg(feature = "runtime")]
use crate::Socket;
use crate::{
bind, query, slice_iter, CancelToken, Client, CopyInSink, Error, Portal, Row,
bind, query, slice_iter, CancelToken, Client, CopyInSink, Error, IsolationLevel, Portal, Row,
SimpleQueryMessage, Statement, ToStatement,
};
use bytes::Buf;
Expand Down Expand Up @@ -56,14 +56,78 @@ impl<'a> Drop for Transaction<'a> {
}

impl<'a> Transaction<'a> {
pub(crate) fn new(client: &'a mut Client) -> Transaction<'a> {
fn new(client: &'a mut Client) -> Transaction<'a> {
Transaction {
client,
savepoint: None,
done: false,
}
}

pub(crate) async fn begin(client: &'a mut Client) -> Result<Transaction<'a>, Error> {
let transaction = Transaction::new(client);

transaction.client.batch_execute("BEGIN").await?;

Ok(transaction)
}

pub(crate) async fn start(
client: &'a mut Client,
isolation_level: Option<IsolationLevel>,
read_only: Option<bool>,
deferrable: Option<bool>,
) -> Result<Transaction<'a>, Error> {
let transaction = Transaction::new(client);

let mut query = "START TRANSACTION".to_string();
let mut first = true;

if let Some(level) = isolation_level {
first = false;

query.push_str(" ISOLATION LEVEL ");
let level = match level {
IsolationLevel::ReadUncommitted => "READ UNCOMMITTED",
IsolationLevel::ReadCommitted => "READ COMMITTED",
IsolationLevel::RepeatableRead => "REPEATABLE READ",
IsolationLevel::Serializable => "SERIALIZABLE",
};
query.push_str(level);
}

if let Some(read_only) = read_only {
if !first {
query.push(',');
}
first = false;

let s = if read_only {
" READ ONLY"
} else {
" READ WRITE"
};
query.push_str(s);
}

if let Some(deferrable) = deferrable {
if !first {
query.push(',');
}

let s = if deferrable {
" DEFERRABLE"
} else {
" NOT DEFERRABLE"
};
query.push_str(s);
}

transaction.client.batch_execute(&query).await?;

Ok(transaction)
}

/// Consumes the transaction, committing all changes made within it.
pub async fn commit(mut self) -> Result<(), Error> {
self.done = true;
Expand Down
53 changes: 7 additions & 46 deletions tokio-postgres/src/transaction_builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -63,51 +63,12 @@ impl<'a> TransactionBuilder<'a> {
///
/// The transaction will roll back by default - use the `commit` method to commit it.
pub async fn start(self) -> Result<Transaction<'a>, Error> {
let mut query = "START TRANSACTION".to_string();
let mut first = true;

if let Some(level) = self.isolation_level {
first = false;

query.push_str(" ISOLATION LEVEL ");
let level = match level {
IsolationLevel::ReadUncommitted => "READ UNCOMMITTED",
IsolationLevel::ReadCommitted => "READ COMMITTED",
IsolationLevel::RepeatableRead => "REPEATABLE READ",
IsolationLevel::Serializable => "SERIALIZABLE",
};
query.push_str(level);
}

if let Some(read_only) = self.read_only {
if !first {
query.push(',');
}
first = false;

let s = if read_only {
" READ ONLY"
} else {
" READ WRITE"
};
query.push_str(s);
}

if let Some(deferrable) = self.deferrable {
if !first {
query.push(',');
}

let s = if deferrable {
" DEFERRABLE"
} else {
" NOT DEFERRABLE"
};
query.push_str(s);
}

self.client.batch_execute(&query).await?;

Ok(Transaction::new(self.client))
Transaction::start(
self.client,
self.isolation_level,
self.read_only,
self.deferrable,
)
.await
}
}
24 changes: 24 additions & 0 deletions tokio-postgres/tests/test/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -448,6 +448,30 @@ async fn transaction_future_cancellation() {
}
}

#[tokio::test]
async fn start_transaction_future_cancellation() {
let mut client = connect("user=postgres").await;

for i in 0.. {
let done = {
let txn = client.build_transaction().start();
let fut = Cancellable {
fut: txn,
polls_left: i,
};
fut.await
.map(|res| res.expect("transaction failed"))
.is_some()
};

assert!(!in_transaction(&client).await);

if done {
break;
}
}
}

#[tokio::test]
async fn transaction_commit_future_cancellation() {
let mut client = connect("user=postgres").await;
Expand Down