Skip to content

Commit 9a3e9c0

Browse files
committed
Fix transction rollback on Future drop
- fixes drop of `Future` after synchronous message send but before `Transaction` object creation - follow-on to f6189a9 for `TransactionBuilder` and savepoint `Transaction::transaction` creation
1 parent 64caf4c commit 9a3e9c0

File tree

3 files changed

+32
-18
lines changed

3 files changed

+32
-18
lines changed

tokio-postgres/src/client.rs

Lines changed: 21 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,9 @@ use crate::types::{Oid, ToSql, Type};
1313
#[cfg(feature = "runtime")]
1414
use crate::Socket;
1515
use crate::{
16-
copy_in, copy_out, prepare, query, simple_query, slice_iter, CancelToken, CopyInSink, Error,
17-
Row, SimpleQueryMessage, Statement, ToStatement, Transaction, TransactionBuilder,
16+
copy_in, copy_out, prepare, query, simple_query, slice_iter, transaction::Savepoint,
17+
CancelToken, CopyInSink, Error, Row, SimpleQueryMessage, Statement, ToStatement, Transaction,
18+
TransactionBuilder,
1819
};
1920
use bytes::{Buf, BytesMut};
2021
use fallible_iterator::FallibleIterator;
@@ -469,8 +470,17 @@ impl Client {
469470
///
470471
/// The transaction will roll back by default - use the `commit` method to commit it.
471472
pub async fn transaction(&mut self) -> Result<Transaction<'_>, Error> {
473+
self.build_transaction().start().await
474+
}
475+
476+
pub(crate) async fn start_transaction_with_rollback(
477+
&mut self,
478+
query: &str,
479+
savepoint: Option<Savepoint>,
480+
) -> Result<Transaction<'_>, Error> {
472481
struct RollbackIfNotDone<'me> {
473482
client: &'me Client,
483+
savepoint: Option<&'me Savepoint>,
474484
done: bool,
475485
}
476486

@@ -480,8 +490,13 @@ impl Client {
480490
return;
481491
}
482492

493+
let query = if let Some(sp) = self.savepoint {
494+
format!("ROLLBACK TO {}", sp.name)
495+
} else {
496+
"ROLLBACK".to_string()
497+
};
483498
let buf = self.client.inner().with_buf(|buf| {
484-
frontend::query("ROLLBACK", buf).unwrap();
499+
frontend::query(&query, buf).unwrap();
485500
buf.split().freeze()
486501
});
487502
let _ = self
@@ -499,13 +514,14 @@ impl Client {
499514
{
500515
let mut cleaner = RollbackIfNotDone {
501516
client: self,
517+
savepoint: savepoint.as_ref(),
502518
done: false,
503519
};
504-
self.batch_execute("BEGIN").await?;
520+
self.batch_execute(query).await?;
505521
cleaner.done = true;
506522
}
507523

508-
Ok(Transaction::new(self))
524+
Ok(Transaction::new(self, savepoint))
509525
}
510526

511527
/// Returns a builder for a transaction with custom settings.

tokio-postgres/src/transaction.rs

Lines changed: 8 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -28,8 +28,8 @@ pub struct Transaction<'a> {
2828
}
2929

3030
/// A representation of a PostgreSQL database savepoint.
31-
struct Savepoint {
32-
name: String,
31+
pub(crate) struct Savepoint {
32+
pub(crate) name: String,
3333
depth: u32,
3434
}
3535

@@ -56,10 +56,10 @@ impl<'a> Drop for Transaction<'a> {
5656
}
5757

5858
impl<'a> Transaction<'a> {
59-
pub(crate) fn new(client: &'a mut Client) -> Transaction<'a> {
59+
pub(crate) fn new(client: &'a mut Client, savepoint: Option<Savepoint>) -> Transaction<'a> {
6060
Transaction {
6161
client,
62-
savepoint: None,
62+
savepoint,
6363
done: false,
6464
}
6565
}
@@ -298,13 +298,11 @@ impl<'a> Transaction<'a> {
298298
let depth = self.savepoint.as_ref().map_or(0, |sp| sp.depth) + 1;
299299
let name = name.unwrap_or_else(|| format!("sp_{}", depth));
300300
let query = format!("SAVEPOINT {}", name);
301-
self.batch_execute(&query).await?;
301+
let savepoint = Savepoint { name, depth };
302302

303-
Ok(Transaction {
304-
client: self.client,
305-
savepoint: Some(Savepoint { name, depth }),
306-
done: false,
307-
})
303+
self.client
304+
.start_transaction_with_rollback(&query, Some(savepoint))
305+
.await
308306
}
309307

310308
/// Returns a reference to the underlying `Client`.

tokio-postgres/src/transaction_builder.rs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -106,8 +106,8 @@ impl<'a> TransactionBuilder<'a> {
106106
query.push_str(s);
107107
}
108108

109-
self.client.batch_execute(&query).await?;
110-
111-
Ok(Transaction::new(self.client))
109+
self.client
110+
.start_transaction_with_rollback(&query, None)
111+
.await
112112
}
113113
}

0 commit comments

Comments
 (0)