@@ -16,12 +16,17 @@ use diesel::pg::{
16
16
} ;
17
17
use diesel:: query_builder:: bind_collector:: RawBytesBindCollector ;
18
18
use diesel:: query_builder:: { AsQuery , QueryBuilder , QueryFragment , QueryId } ;
19
+ use diesel:: result:: DatabaseErrorKind ;
19
20
use diesel:: { ConnectionError , ConnectionResult , QueryResult } ;
20
21
use futures_util:: future:: BoxFuture ;
22
+ use futures_util:: future:: Either ;
21
23
use futures_util:: stream:: { BoxStream , TryStreamExt } ;
24
+ use futures_util:: TryFutureExt ;
22
25
use futures_util:: { Future , FutureExt , StreamExt } ;
23
26
use std:: borrow:: Cow ;
24
27
use std:: sync:: Arc ;
28
+ use tokio:: sync:: broadcast;
29
+ use tokio:: sync:: oneshot;
25
30
use tokio:: sync:: Mutex ;
26
31
use tokio_postgres:: types:: ToSql ;
27
32
use tokio_postgres:: types:: Type ;
@@ -102,12 +107,20 @@ pub struct AsyncPgConnection {
102
107
stmt_cache : Arc < Mutex < StmtCache < diesel:: pg:: Pg , Statement > > > ,
103
108
transaction_state : Arc < Mutex < AnsiTransactionManager > > ,
104
109
metadata_cache : Arc < Mutex < PgMetadataCache > > ,
110
+ connection_future : Option < broadcast:: Receiver < Arc < tokio_postgres:: Error > > > ,
111
+ shutdown_channel : Option < oneshot:: Sender < ( ) > > ,
105
112
}
106
113
107
114
#[ async_trait:: async_trait]
108
115
impl SimpleAsyncConnection for AsyncPgConnection {
109
116
async fn batch_execute ( & mut self , query : & str ) -> QueryResult < ( ) > {
110
- Ok ( self . conn . batch_execute ( query) . await . map_err ( ErrorHelper ) ?)
117
+ let connection_future = self . connection_future . as_ref ( ) . map ( |rx| rx. resubscribe ( ) ) ;
118
+ let batch_execute = self
119
+ . conn
120
+ . batch_execute ( query)
121
+ . map_err ( ErrorHelper )
122
+ . map_err ( Into :: into) ;
123
+ drive_future ( connection_future, batch_execute) . await
111
124
}
112
125
}
113
126
@@ -124,29 +137,37 @@ impl AsyncConnection for AsyncPgConnection {
124
137
let ( client, connection) = tokio_postgres:: connect ( database_url, tokio_postgres:: NoTls )
125
138
. await
126
139
. map_err ( ErrorHelper ) ?;
140
+ let ( tx, rx) = tokio:: sync:: broadcast:: channel ( 1 ) ;
141
+ let ( shutdown_tx, shutdown_rx) = tokio:: sync:: oneshot:: channel ( ) ;
127
142
tokio:: spawn ( async move {
128
- if let Err ( e) = connection. await {
129
- eprintln ! ( "connection error: {e}" ) ;
143
+ match futures_util:: future:: select ( shutdown_rx, connection) . await {
144
+ Either :: Left ( _) | Either :: Right ( ( Ok ( _) , _) ) => { }
145
+ Either :: Right ( ( Err ( e) , _) ) => {
146
+ let _ = tx. send ( Arc :: new ( e) ) ;
147
+ }
130
148
}
131
149
} ) ;
132
- Self :: try_from ( client) . await
150
+
151
+ Self :: setup ( client, Some ( rx) , Some ( shutdown_tx) ) . await
133
152
}
134
153
135
154
fn load < ' conn , ' query , T > ( & ' conn mut self , source : T ) -> Self :: LoadFuture < ' conn , ' query >
136
155
where
137
156
T : AsQuery + ' query ,
138
157
T :: Query : QueryFragment < Self :: Backend > + QueryId + ' query ,
139
158
{
159
+ let connection_future = self . connection_future . as_ref ( ) . map ( |rx| rx. resubscribe ( ) ) ;
140
160
let query = source. as_query ( ) ;
141
- self . with_prepared_statement ( query, |conn, stmt, binds| async move {
161
+ let load_future = self . with_prepared_statement ( query, |conn, stmt, binds| async move {
142
162
let res = conn. query_raw ( & stmt, binds) . await . map_err ( ErrorHelper ) ?;
143
163
144
164
Ok ( res
145
165
. map_err ( |e| diesel:: result:: Error :: from ( ErrorHelper ( e) ) )
146
166
. map_ok ( PgRow :: new)
147
167
. boxed ( ) )
148
- } )
149
- . boxed ( )
168
+ } ) ;
169
+
170
+ drive_future ( connection_future, load_future) . boxed ( )
150
171
}
151
172
152
173
fn execute_returning_count < ' conn , ' query , T > (
@@ -156,7 +177,8 @@ impl AsyncConnection for AsyncPgConnection {
156
177
where
157
178
T : QueryFragment < Self :: Backend > + QueryId + ' query ,
158
179
{
159
- self . with_prepared_statement ( source, |conn, stmt, binds| async move {
180
+ let connection_future = self . connection_future . as_ref ( ) . map ( |rx| rx. resubscribe ( ) ) ;
181
+ let execute = self . with_prepared_statement ( source, |conn, stmt, binds| async move {
160
182
let binds = binds
161
183
. iter ( )
162
184
. map ( |b| b as & ( dyn ToSql + Sync ) )
@@ -166,8 +188,8 @@ impl AsyncConnection for AsyncPgConnection {
166
188
. await
167
189
. map_err ( ErrorHelper ) ?;
168
190
Ok ( res as usize )
169
- } )
170
- . boxed ( )
191
+ } ) ;
192
+ drive_future ( connection_future , execute ) . boxed ( )
171
193
}
172
194
173
195
fn transaction_state ( & mut self ) -> & mut AnsiTransactionManager {
@@ -182,15 +204,21 @@ impl AsyncConnection for AsyncPgConnection {
182
204
}
183
205
}
184
206
207
+ impl Drop for AsyncPgConnection {
208
+ fn drop ( & mut self ) {
209
+ if let Some ( tx) = self . shutdown_channel . take ( ) {
210
+ let _ = tx. send ( ( ) ) ;
211
+ }
212
+ }
213
+ }
214
+
185
215
#[ inline( always) ]
186
216
fn update_transaction_manager_status < T > (
187
217
query_result : QueryResult < T > ,
188
218
transaction_manager : & mut AnsiTransactionManager ,
189
219
) -> QueryResult < T > {
190
- if let Err ( diesel:: result:: Error :: DatabaseError (
191
- diesel:: result:: DatabaseErrorKind :: SerializationFailure ,
192
- _,
193
- ) ) = query_result
220
+ if let Err ( diesel:: result:: Error :: DatabaseError ( DatabaseErrorKind :: SerializationFailure , _) ) =
221
+ query_result
194
222
{
195
223
transaction_manager
196
224
. status
@@ -270,11 +298,21 @@ impl AsyncPgConnection {
270
298
271
299
/// Construct a new `AsyncPgConnection` instance from an existing [`tokio_postgres::Client`]
272
300
pub async fn try_from ( conn : tokio_postgres:: Client ) -> ConnectionResult < Self > {
301
+ Self :: setup ( conn, None , None ) . await
302
+ }
303
+
304
+ async fn setup (
305
+ conn : tokio_postgres:: Client ,
306
+ connection_future : Option < broadcast:: Receiver < Arc < tokio_postgres:: Error > > > ,
307
+ shutdown_channel : Option < oneshot:: Sender < ( ) > > ,
308
+ ) -> ConnectionResult < Self > {
273
309
let mut conn = Self {
274
310
conn : Arc :: new ( conn) ,
275
311
stmt_cache : Arc :: new ( Mutex :: new ( StmtCache :: new ( ) ) ) ,
276
312
transaction_state : Arc :: new ( Mutex :: new ( AnsiTransactionManager :: default ( ) ) ) ,
277
313
metadata_cache : Arc :: new ( Mutex :: new ( PgMetadataCache :: new ( ) ) ) ,
314
+ connection_future,
315
+ shutdown_channel,
278
316
} ;
279
317
conn. set_config_options ( )
280
318
. await
@@ -470,6 +508,29 @@ async fn lookup_type(
470
508
Ok ( ( r. get ( 0 ) , r. get ( 1 ) ) )
471
509
}
472
510
511
+ async fn drive_future < R > (
512
+ connection_future : Option < broadcast:: Receiver < Arc < tokio_postgres:: Error > > > ,
513
+ client_future : impl Future < Output = Result < R , diesel:: result:: Error > > ,
514
+ ) -> Result < R , diesel:: result:: Error > {
515
+ if let Some ( mut connection_future) = connection_future {
516
+ let client_future = std:: pin:: pin!( client_future) ;
517
+ let connection_future = std:: pin:: pin!( connection_future. recv( ) ) ;
518
+ match futures_util:: future:: select ( client_future, connection_future) . await {
519
+ Either :: Left ( ( res, _) ) => res,
520
+ // we got an error from the background task
521
+ // return it to the user
522
+ Either :: Right ( ( Ok ( e) , _) ) => Err ( self :: error_helper:: from_tokio_postgres_error ( e) ) ,
523
+ // seems like the background thread died for whatever reason
524
+ Either :: Right ( ( Err ( e) , _) ) => Err ( diesel:: result:: Error :: DatabaseError (
525
+ DatabaseErrorKind :: UnableToSendCommand ,
526
+ Box :: new ( e. to_string ( ) ) ,
527
+ ) ) ,
528
+ }
529
+ } else {
530
+ client_future. await
531
+ }
532
+ }
533
+
473
534
#[ cfg( any(
474
535
feature = "deadpool" ,
475
536
feature = "bb8" ,
0 commit comments