@@ -21,6 +21,7 @@ use futures_util::future::BoxFuture;
21
21
use futures_util:: stream:: { BoxStream , TryStreamExt } ;
22
22
use futures_util:: { Future , FutureExt , StreamExt } ;
23
23
use std:: borrow:: Cow ;
24
+ use std:: ops:: DerefMut ;
24
25
use std:: sync:: Arc ;
25
26
use tokio:: sync:: Mutex ;
26
27
use tokio_postgres:: types:: ToSql ;
@@ -102,6 +103,7 @@ pub struct AsyncPgConnection {
102
103
stmt_cache : Arc < Mutex < StmtCache < diesel:: pg:: Pg , Statement > > > ,
103
104
transaction_state : Arc < Mutex < AnsiTransactionManager > > ,
104
105
metadata_cache : Arc < Mutex < PgMetadataCache > > ,
106
+ error_receiver : Arc < Mutex < tokio:: sync:: oneshot:: Receiver < diesel:: result:: Error > > > ,
105
107
}
106
108
107
109
#[ async_trait:: async_trait]
@@ -124,12 +126,19 @@ impl AsyncConnection for AsyncPgConnection {
124
126
let ( client, connection) = tokio_postgres:: connect ( database_url, tokio_postgres:: NoTls )
125
127
. await
126
128
. map_err ( ErrorHelper ) ?;
127
- tokio:: spawn ( async move {
128
- if let Err ( e) = connection. await {
129
- eprintln ! ( "connection error: {e}" ) ;
129
+ // If there is a connection error, we capture it in this channel and make when
130
+ // the user next calls one of the functions on the connection in this trait, we
131
+ // return the error instead of the inner result.
132
+ let ( sender, receiver) = tokio:: sync:: oneshot:: channel ( ) ;
133
+ tokio:: spawn ( async {
134
+ if let Err ( connection_error) = connection. await {
135
+ let connection_error = diesel:: result:: Error :: from ( ErrorHelper ( connection_error) ) ;
136
+ if let Err ( send_error) = sender. send ( connection_error) {
137
+ eprintln ! ( "Failed to send connection error through channel, connection must have been dropped: {}" , send_error) ;
138
+ }
130
139
}
131
140
} ) ;
132
- Self :: try_from ( client) . await
141
+ Self :: try_from_with_error_receiver ( client, receiver ) . await
133
142
}
134
143
135
144
fn load < ' conn , ' query , T > ( & ' conn mut self , source : T ) -> Self :: LoadFuture < ' conn , ' query >
@@ -270,11 +279,24 @@ impl AsyncPgConnection {
270
279
271
280
/// Construct a new `AsyncPgConnection` instance from an existing [`tokio_postgres::Client`]
272
281
pub async fn try_from ( conn : tokio_postgres:: Client ) -> ConnectionResult < Self > {
282
+ // We create a dummy receiver here. If the user is calling this, they have
283
+ // created their own client and connection and are handling any error in
284
+ // the latter themselves.
285
+ Self :: try_from_with_error_receiver ( conn, tokio:: sync:: oneshot:: channel ( ) . 1 ) . await
286
+ }
287
+
288
+ /// Construct a new `AsyncPgConnection` instance from an existing [`tokio_postgres::Client`]
289
+ /// and a [`tokio::sync::oneshot::Receiver`] for receiving an error from the connection.
290
+ async fn try_from_with_error_receiver (
291
+ conn : tokio_postgres:: Client ,
292
+ error_receiver : tokio:: sync:: oneshot:: Receiver < diesel:: result:: Error > ,
293
+ ) -> ConnectionResult < Self > {
273
294
let mut conn = Self {
274
295
conn : Arc :: new ( conn) ,
275
296
stmt_cache : Arc :: new ( Mutex :: new ( StmtCache :: new ( ) ) ) ,
276
297
transaction_state : Arc :: new ( Mutex :: new ( AnsiTransactionManager :: default ( ) ) ) ,
277
298
metadata_cache : Arc :: new ( Mutex :: new ( PgMetadataCache :: new ( ) ) ) ,
299
+ error_receiver : Arc :: new ( Mutex :: new ( error_receiver) ) ,
278
300
} ;
279
301
conn. set_config_options ( )
280
302
. await
@@ -340,7 +362,7 @@ impl AsyncPgConnection {
340
362
let metadata_cache = self . metadata_cache . clone ( ) ;
341
363
let tm = self . transaction_state . clone ( ) ;
342
364
343
- async move {
365
+ let f = async move {
344
366
let sql = sql?;
345
367
let is_safe_to_cache_prepared = is_safe_to_cache_prepared?;
346
368
collect_bind_result?;
@@ -411,6 +433,17 @@ impl AsyncPgConnection {
411
433
let res = callback ( raw_connection, stmt. clone ( ) , binds) . await ;
412
434
let mut tm = tm. lock ( ) . await ;
413
435
update_transaction_manager_status ( res, & mut tm)
436
+ } ;
437
+
438
+ let er = self . error_receiver . clone ( ) ;
439
+ async move {
440
+ let mut error_receiver = er. lock ( ) . await ;
441
+ tokio:: select! {
442
+ // If there is an error from the connection in the channel, return that.
443
+ error = error_receiver. deref_mut( ) => Err ( error. unwrap( ) ) ,
444
+ // Otherwise return the inner result.
445
+ res = f => res,
446
+ }
414
447
}
415
448
. boxed ( )
416
449
}
0 commit comments