Skip to content

Commit f61f8ea

Browse files
committed
Expose underlying connection errors to user of AsyncPgConnection
1 parent 1e18b37 commit f61f8ea

File tree

1 file changed

+38
-5
lines changed

1 file changed

+38
-5
lines changed

src/pg/mod.rs

Lines changed: 38 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ use futures_util::future::BoxFuture;
2121
use futures_util::stream::{BoxStream, TryStreamExt};
2222
use futures_util::{Future, FutureExt, StreamExt};
2323
use std::borrow::Cow;
24+
use std::ops::DerefMut;
2425
use std::sync::Arc;
2526
use tokio::sync::Mutex;
2627
use tokio_postgres::types::ToSql;
@@ -102,6 +103,7 @@ pub struct AsyncPgConnection {
102103
stmt_cache: Arc<Mutex<StmtCache<diesel::pg::Pg, Statement>>>,
103104
transaction_state: Arc<Mutex<AnsiTransactionManager>>,
104105
metadata_cache: Arc<Mutex<PgMetadataCache>>,
106+
error_receiver: Arc<Mutex<tokio::sync::oneshot::Receiver<diesel::result::Error>>>,
105107
}
106108

107109
#[async_trait::async_trait]
@@ -124,12 +126,19 @@ impl AsyncConnection for AsyncPgConnection {
124126
let (client, connection) = tokio_postgres::connect(database_url, tokio_postgres::NoTls)
125127
.await
126128
.map_err(ErrorHelper)?;
127-
tokio::spawn(async move {
128-
if let Err(e) = connection.await {
129-
eprintln!("connection error: {e}");
129+
// If there is a connection error, we capture it in this channel and make when
130+
// the user next calls one of the functions on the connection in this trait, we
131+
// return the error instead of the inner result.
132+
let (sender, receiver) = tokio::sync::oneshot::channel();
133+
tokio::spawn(async {
134+
if let Err(connection_error) = connection.await {
135+
let connection_error = diesel::result::Error::from(ErrorHelper(connection_error));
136+
if let Err(send_error) = sender.send(connection_error) {
137+
eprintln!("Failed to send connection error through channel, connection must have been dropped: {}", send_error);
138+
}
130139
}
131140
});
132-
Self::try_from(client).await
141+
Self::try_from_with_error_receiver(client, receiver).await
133142
}
134143

135144
fn load<'conn, 'query, T>(&'conn mut self, source: T) -> Self::LoadFuture<'conn, 'query>
@@ -270,11 +279,24 @@ impl AsyncPgConnection {
270279

271280
/// Construct a new `AsyncPgConnection` instance from an existing [`tokio_postgres::Client`]
272281
pub async fn try_from(conn: tokio_postgres::Client) -> ConnectionResult<Self> {
282+
// We create a dummy receiver here. If the user is calling this, they have
283+
// created their own client and connection and are handling any error in
284+
// the latter themselves.
285+
Self::try_from_with_error_receiver(conn, tokio::sync::oneshot::channel().1).await
286+
}
287+
288+
/// Construct a new `AsyncPgConnection` instance from an existing [`tokio_postgres::Client`]
289+
/// and a [`tokio::sync::oneshot::Receiver`] for receiving an error from the connection.
290+
async fn try_from_with_error_receiver(
291+
conn: tokio_postgres::Client,
292+
error_receiver: tokio::sync::oneshot::Receiver<diesel::result::Error>,
293+
) -> ConnectionResult<Self> {
273294
let mut conn = Self {
274295
conn: Arc::new(conn),
275296
stmt_cache: Arc::new(Mutex::new(StmtCache::new())),
276297
transaction_state: Arc::new(Mutex::new(AnsiTransactionManager::default())),
277298
metadata_cache: Arc::new(Mutex::new(PgMetadataCache::new())),
299+
error_receiver: Arc::new(Mutex::new(error_receiver)),
278300
};
279301
conn.set_config_options()
280302
.await
@@ -340,7 +362,7 @@ impl AsyncPgConnection {
340362
let metadata_cache = self.metadata_cache.clone();
341363
let tm = self.transaction_state.clone();
342364

343-
async move {
365+
let f = async move {
344366
let sql = sql?;
345367
let is_safe_to_cache_prepared = is_safe_to_cache_prepared?;
346368
collect_bind_result?;
@@ -411,6 +433,17 @@ impl AsyncPgConnection {
411433
let res = callback(raw_connection, stmt.clone(), binds).await;
412434
let mut tm = tm.lock().await;
413435
update_transaction_manager_status(res, &mut tm)
436+
};
437+
438+
let er = self.error_receiver.clone();
439+
async move {
440+
let mut error_receiver = er.lock().await;
441+
tokio::select! {
442+
// If there is an error from the connection in the channel, return that.
443+
error = error_receiver.deref_mut() => Err(error.unwrap()),
444+
// Otherwise return the inner result.
445+
res = f => res,
446+
}
414447
}
415448
.boxed()
416449
}

0 commit comments

Comments
 (0)