Skip to content

Commit 6c73e23

Browse files
committed
Better handling of the postgres connection background task
* Propagate errors to the user * Cancel the task if we drop the connection
1 parent 287e380 commit 6c73e23

File tree

2 files changed

+112
-42
lines changed

2 files changed

+112
-42
lines changed

src/pg/error_helper.rs

Lines changed: 37 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,6 @@
1+
use std::error::Error;
2+
use std::sync::Arc;
3+
14
use diesel::ConnectionError;
25

36
pub(super) struct ErrorHelper(pub(super) tokio_postgres::Error);
@@ -10,40 +13,46 @@ impl From<ErrorHelper> for ConnectionError {
1013

1114
impl From<ErrorHelper> for diesel::result::Error {
1215
fn from(ErrorHelper(postgres_error): ErrorHelper) -> Self {
13-
use diesel::result::DatabaseErrorKind::*;
14-
use tokio_postgres::error::SqlState;
16+
from_tokio_postgres_error(Arc::new(postgres_error))
17+
}
18+
}
1519

16-
match postgres_error.code() {
17-
Some(code) => {
18-
let kind = match *code {
19-
SqlState::UNIQUE_VIOLATION => UniqueViolation,
20-
SqlState::FOREIGN_KEY_VIOLATION => ForeignKeyViolation,
21-
SqlState::T_R_SERIALIZATION_FAILURE => SerializationFailure,
22-
SqlState::READ_ONLY_SQL_TRANSACTION => ReadOnlyTransaction,
23-
SqlState::NOT_NULL_VIOLATION => NotNullViolation,
24-
SqlState::CHECK_VIOLATION => CheckViolation,
25-
_ => Unknown,
26-
};
20+
pub(super) fn from_tokio_postgres_error(
21+
postgres_error: Arc<tokio_postgres::Error>,
22+
) -> diesel::result::Error {
23+
use diesel::result::DatabaseErrorKind::*;
24+
use tokio_postgres::error::SqlState;
2725

28-
diesel::result::Error::DatabaseError(
29-
kind,
30-
Box::new(PostgresDbErrorWrapper(
31-
postgres_error
32-
.into_source()
33-
.and_then(|e| e.downcast::<tokio_postgres::error::DbError>().ok())
34-
.expect("It's a db error, because we've got a SQLState code above"),
35-
)) as _,
36-
)
37-
}
38-
None => diesel::result::Error::DatabaseError(
39-
UnableToSendCommand,
40-
Box::new(postgres_error.to_string()),
41-
),
26+
match postgres_error.code() {
27+
Some(code) => {
28+
let kind = match *code {
29+
SqlState::UNIQUE_VIOLATION => UniqueViolation,
30+
SqlState::FOREIGN_KEY_VIOLATION => ForeignKeyViolation,
31+
SqlState::T_R_SERIALIZATION_FAILURE => SerializationFailure,
32+
SqlState::READ_ONLY_SQL_TRANSACTION => ReadOnlyTransaction,
33+
SqlState::NOT_NULL_VIOLATION => NotNullViolation,
34+
SqlState::CHECK_VIOLATION => CheckViolation,
35+
_ => Unknown,
36+
};
37+
38+
diesel::result::Error::DatabaseError(
39+
kind,
40+
Box::new(PostgresDbErrorWrapper(
41+
postgres_error
42+
.source()
43+
.and_then(|e| e.downcast_ref::<tokio_postgres::error::DbError>().cloned())
44+
.expect("It's a db error, because we've got a SQLState code above"),
45+
)) as _,
46+
)
4247
}
48+
None => diesel::result::Error::DatabaseError(
49+
UnableToSendCommand,
50+
Box::new(postgres_error.to_string()),
51+
),
4352
}
4453
}
4554

46-
struct PostgresDbErrorWrapper(Box<tokio_postgres::error::DbError>);
55+
struct PostgresDbErrorWrapper(tokio_postgres::error::DbError);
4756

4857
impl diesel::result::DatabaseErrorInformation for PostgresDbErrorWrapper {
4958
fn message(&self) -> &str {

src/pg/mod.rs

Lines changed: 75 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -16,12 +16,17 @@ use diesel::pg::{
1616
};
1717
use diesel::query_builder::bind_collector::RawBytesBindCollector;
1818
use diesel::query_builder::{AsQuery, QueryBuilder, QueryFragment, QueryId};
19+
use diesel::result::DatabaseErrorKind;
1920
use diesel::{ConnectionError, ConnectionResult, QueryResult};
2021
use futures_util::future::BoxFuture;
22+
use futures_util::future::Either;
2123
use futures_util::stream::{BoxStream, TryStreamExt};
24+
use futures_util::TryFutureExt;
2225
use futures_util::{Future, FutureExt, StreamExt};
2326
use std::borrow::Cow;
2427
use std::sync::Arc;
28+
use tokio::sync::broadcast;
29+
use tokio::sync::oneshot;
2530
use tokio::sync::Mutex;
2631
use tokio_postgres::types::ToSql;
2732
use tokio_postgres::types::Type;
@@ -102,12 +107,20 @@ pub struct AsyncPgConnection {
102107
stmt_cache: Arc<Mutex<StmtCache<diesel::pg::Pg, Statement>>>,
103108
transaction_state: Arc<Mutex<AnsiTransactionManager>>,
104109
metadata_cache: Arc<Mutex<PgMetadataCache>>,
110+
connection_future: Option<broadcast::Receiver<Arc<tokio_postgres::Error>>>,
111+
shutdown_channel: Option<oneshot::Sender<()>>,
105112
}
106113

107114
#[async_trait::async_trait]
108115
impl SimpleAsyncConnection for AsyncPgConnection {
109116
async fn batch_execute(&mut self, query: &str) -> QueryResult<()> {
110-
Ok(self.conn.batch_execute(query).await.map_err(ErrorHelper)?)
117+
let connection_future = self.connection_future.as_ref().map(|rx| rx.resubscribe());
118+
let batch_execute = self
119+
.conn
120+
.batch_execute(query)
121+
.map_err(ErrorHelper)
122+
.map_err(Into::into);
123+
drive_future(connection_future, batch_execute).await
111124
}
112125
}
113126

@@ -124,29 +137,37 @@ impl AsyncConnection for AsyncPgConnection {
124137
let (client, connection) = tokio_postgres::connect(database_url, tokio_postgres::NoTls)
125138
.await
126139
.map_err(ErrorHelper)?;
140+
let (tx, rx) = tokio::sync::broadcast::channel(1);
141+
let (shutdown_tx, shutdown_rx) = tokio::sync::oneshot::channel();
127142
tokio::spawn(async move {
128-
if let Err(e) = connection.await {
129-
eprintln!("connection error: {e}");
143+
match futures_util::future::select(shutdown_rx, connection).await {
144+
Either::Left(_) | Either::Right((Ok(_), _)) => {}
145+
Either::Right((Err(e), _)) => {
146+
let _ = tx.send(Arc::new(e));
147+
}
130148
}
131149
});
132-
Self::try_from(client).await
150+
151+
Self::setup(client, Some(rx), Some(shutdown_tx)).await
133152
}
134153

135154
fn load<'conn, 'query, T>(&'conn mut self, source: T) -> Self::LoadFuture<'conn, 'query>
136155
where
137156
T: AsQuery + 'query,
138157
T::Query: QueryFragment<Self::Backend> + QueryId + 'query,
139158
{
159+
let connection_future = self.connection_future.as_ref().map(|rx| rx.resubscribe());
140160
let query = source.as_query();
141-
self.with_prepared_statement(query, |conn, stmt, binds| async move {
161+
let load_future = self.with_prepared_statement(query, |conn, stmt, binds| async move {
142162
let res = conn.query_raw(&stmt, binds).await.map_err(ErrorHelper)?;
143163

144164
Ok(res
145165
.map_err(|e| diesel::result::Error::from(ErrorHelper(e)))
146166
.map_ok(PgRow::new)
147167
.boxed())
148-
})
149-
.boxed()
168+
});
169+
170+
drive_future(connection_future, load_future).boxed()
150171
}
151172

152173
fn execute_returning_count<'conn, 'query, T>(
@@ -156,7 +177,8 @@ impl AsyncConnection for AsyncPgConnection {
156177
where
157178
T: QueryFragment<Self::Backend> + QueryId + 'query,
158179
{
159-
self.with_prepared_statement(source, |conn, stmt, binds| async move {
180+
let connection_future = self.connection_future.as_ref().map(|rx| rx.resubscribe());
181+
let execute = self.with_prepared_statement(source, |conn, stmt, binds| async move {
160182
let binds = binds
161183
.iter()
162184
.map(|b| b as &(dyn ToSql + Sync))
@@ -166,8 +188,8 @@ impl AsyncConnection for AsyncPgConnection {
166188
.await
167189
.map_err(ErrorHelper)?;
168190
Ok(res as usize)
169-
})
170-
.boxed()
191+
});
192+
drive_future(connection_future, execute).boxed()
171193
}
172194

173195
fn transaction_state(&mut self) -> &mut AnsiTransactionManager {
@@ -182,15 +204,21 @@ impl AsyncConnection for AsyncPgConnection {
182204
}
183205
}
184206

207+
impl Drop for AsyncPgConnection {
208+
fn drop(&mut self) {
209+
if let Some(tx) = self.shutdown_channel.take() {
210+
let _ = tx.send(());
211+
}
212+
}
213+
}
214+
185215
#[inline(always)]
186216
fn update_transaction_manager_status<T>(
187217
query_result: QueryResult<T>,
188218
transaction_manager: &mut AnsiTransactionManager,
189219
) -> QueryResult<T> {
190-
if let Err(diesel::result::Error::DatabaseError(
191-
diesel::result::DatabaseErrorKind::SerializationFailure,
192-
_,
193-
)) = query_result
220+
if let Err(diesel::result::Error::DatabaseError(DatabaseErrorKind::SerializationFailure, _)) =
221+
query_result
194222
{
195223
transaction_manager
196224
.status
@@ -270,11 +298,21 @@ impl AsyncPgConnection {
270298

271299
/// Construct a new `AsyncPgConnection` instance from an existing [`tokio_postgres::Client`]
272300
pub async fn try_from(conn: tokio_postgres::Client) -> ConnectionResult<Self> {
301+
Self::setup(conn, None, None).await
302+
}
303+
304+
async fn setup(
305+
conn: tokio_postgres::Client,
306+
connection_future: Option<broadcast::Receiver<Arc<tokio_postgres::Error>>>,
307+
shutdown_channel: Option<oneshot::Sender<()>>,
308+
) -> ConnectionResult<Self> {
273309
let mut conn = Self {
274310
conn: Arc::new(conn),
275311
stmt_cache: Arc::new(Mutex::new(StmtCache::new())),
276312
transaction_state: Arc::new(Mutex::new(AnsiTransactionManager::default())),
277313
metadata_cache: Arc::new(Mutex::new(PgMetadataCache::new())),
314+
connection_future,
315+
shutdown_channel,
278316
};
279317
conn.set_config_options()
280318
.await
@@ -470,6 +508,29 @@ async fn lookup_type(
470508
Ok((r.get(0), r.get(1)))
471509
}
472510

511+
async fn drive_future<R>(
512+
connection_future: Option<broadcast::Receiver<Arc<tokio_postgres::Error>>>,
513+
client_future: impl Future<Output = Result<R, diesel::result::Error>>,
514+
) -> Result<R, diesel::result::Error> {
515+
if let Some(mut connection_future) = connection_future {
516+
let client_future = std::pin::pin!(client_future);
517+
let connection_future = std::pin::pin!(connection_future.recv());
518+
match futures_util::future::select(client_future, connection_future).await {
519+
Either::Left((res, _)) => res,
520+
// we got an error from the background task
521+
// return it to the user
522+
Either::Right((Ok(e), _)) => Err(self::error_helper::from_tokio_postgres_error(e)),
523+
// seems like the background thread died for whatever reason
524+
Either::Right((Err(e), _)) => Err(diesel::result::Error::DatabaseError(
525+
DatabaseErrorKind::UnableToSendCommand,
526+
Box::new(e.to_string()),
527+
)),
528+
}
529+
} else {
530+
client_future.await
531+
}
532+
}
533+
473534
#[cfg(any(
474535
feature = "deadpool",
475536
feature = "bb8",

0 commit comments

Comments
 (0)