@@ -3,6 +3,7 @@ mod response;
3
3
mod gvl_helpers;
4
4
mod grpc;
5
5
6
+ use hyper_util:: server:: graceful:: GracefulShutdown ;
6
7
use request:: { Request , GrpcRequest } ;
7
8
use response:: { Response , GrpcResponse } ;
8
9
use gvl_helpers:: nogvl;
@@ -11,11 +12,12 @@ use magnus::block::block_proc;
11
12
use magnus:: typed_data:: Obj ;
12
13
use magnus:: { function, method, prelude:: * , Error as MagnusError , IntoValue , Ruby , Value , RString } ;
13
14
use bytes:: Bytes ;
15
+ use tokio:: io:: { AsyncRead , AsyncWrite } ;
14
16
15
17
use std:: cell:: RefCell ;
16
18
use std:: net:: SocketAddr ;
17
19
18
- use tokio:: net:: UnixListener ;
20
+ use tokio:: net:: { TcpListener , UnixListener } ;
19
21
20
22
use std:: sync:: Arc ;
21
23
use tokio:: sync:: { Mutex , oneshot} ;
@@ -31,18 +33,45 @@ use http_body_util::BodyExt;
31
33
32
34
use jemallocator:: Jemalloc ;
33
35
34
- use log:: { debug, info, warn} ;
36
+ use log:: { debug, info, warn, error } ;
35
37
36
38
use env_logger;
37
39
use crate :: response:: BodyWithTrailers ;
38
40
use std:: sync:: Once ;
39
41
use tokio:: time:: timeout;
40
42
43
+ use std:: io;
44
+
45
+ use tokio:: sync:: broadcast;
46
+
41
47
static LOGGER_INIT : Once = Once :: new ( ) ;
42
48
43
49
#[ global_allocator]
44
50
static GLOBAL : Jemalloc = Jemalloc ;
45
51
52
+ trait AsyncStream : AsyncRead + AsyncWrite + Unpin + Send { }
53
+ impl < T : AsyncRead + AsyncWrite + Unpin + Send > AsyncStream for T { }
54
+
55
+ enum Listener {
56
+ Unix ( UnixListener ) ,
57
+ Tcp ( TcpListener ) ,
58
+ }
59
+
60
+ impl Listener {
61
+ async fn accept ( & self ) -> io:: Result < ( Box < dyn AsyncStream > , SocketAddr ) > {
62
+ match self {
63
+ Listener :: Unix ( l) => {
64
+ let ( stream, _) = l. accept ( ) . await ?;
65
+ Ok ( ( Box :: new ( stream) , "0.0.0.0:0" . parse ( ) . unwrap ( ) ) )
66
+ }
67
+ Listener :: Tcp ( l) => {
68
+ let ( stream, addr) = l. accept ( ) . await ?;
69
+ Ok ( ( Box :: new ( stream) , addr) )
70
+ }
71
+ }
72
+ }
73
+ }
74
+
46
75
#[ derive( Clone ) ]
47
76
struct ServerConfig {
48
77
bind_address : String ,
@@ -75,18 +104,19 @@ struct Server {
75
104
work_rx : RefCell < Option < crossbeam_channel:: Receiver < RequestWithCompletion > > > ,
76
105
work_tx : RefCell < Option < Arc < crossbeam_channel:: Sender < RequestWithCompletion > > > > ,
77
106
runtime : RefCell < Option < Arc < tokio:: runtime:: Runtime > > > ,
107
+ shutdown : RefCell < Option < broadcast:: Sender < ( ) > > > ,
78
108
}
79
109
80
110
impl Server {
81
111
pub fn new ( ) -> Self {
82
112
let ( work_tx, work_rx) = crossbeam_channel:: bounded ( 1000 ) ;
83
-
84
113
Self {
85
114
server_handle : Arc :: new ( Mutex :: new ( None ) ) ,
86
115
config : RefCell :: new ( ServerConfig :: new ( ) ) ,
87
116
work_rx : RefCell :: new ( Some ( work_rx) ) ,
88
117
work_tx : RefCell :: new ( Some ( Arc :: new ( work_tx) ) ) ,
89
118
runtime : RefCell :: new ( None ) ,
119
+ shutdown : RefCell :: new ( None ) ,
90
120
}
91
121
}
92
122
@@ -211,6 +241,9 @@ impl Server {
211
241
. ok_or_else ( || MagnusError :: new ( magnus:: exception:: runtime_error ( ) , "Work channel not initialized" ) ) ?
212
242
. clone ( ) ;
213
243
244
+ let ( shutdown_tx, shutdown_rx) = broadcast:: channel ( 1 ) ;
245
+ * self . shutdown . borrow_mut ( ) = Some ( shutdown_tx) ;
246
+
214
247
let mut rt_builder = tokio:: runtime:: Builder :: new_multi_thread ( ) ;
215
248
216
249
rt_builder. enable_all ( ) ;
@@ -225,65 +258,94 @@ impl Server {
225
258
226
259
* self . runtime . borrow_mut ( ) = Some ( rt. clone ( ) ) ;
227
260
228
- rt. block_on ( async {
229
- let work_tx = work_tx. clone ( ) ;
230
-
261
+
262
+ rt. block_on ( async move {
231
263
let server_task = tokio:: spawn ( async move {
232
264
let timer = hyper_util:: rt:: TokioTimer :: new ( ) ;
265
+ let mut builder = auto:: Builder :: new ( hyper_util:: rt:: TokioExecutor :: new ( ) ) ;
266
+ builder. http1 ( )
267
+ . header_read_timeout ( std:: time:: Duration :: from_millis ( config. recv_timeout ) )
268
+ . timer ( timer. clone ( ) ) ;
269
+ builder. http2 ( )
270
+ . keep_alive_interval ( std:: time:: Duration :: from_secs ( 10 ) )
271
+ . timer ( timer) ;
272
+
273
+ let listener = if config. bind_address . starts_with ( "unix:" ) {
274
+ Listener :: Unix ( UnixListener :: bind ( config. bind_address . trim_start_matches ( "unix:" ) ) . unwrap ( ) )
275
+ } else {
276
+ let addr: SocketAddr = config. bind_address . parse ( ) . expect ( "invalid address format" ) ;
277
+ Listener :: Tcp ( TcpListener :: bind ( addr) . await . unwrap ( ) )
278
+ } ;
233
279
234
- if config. bind_address . starts_with ( "unix:" ) {
235
- let path = config. bind_address . trim_start_matches ( "unix:" ) ;
236
- let listener = UnixListener :: bind ( path) . unwrap ( ) ;
237
-
238
- loop {
239
- let ( stream, _) = listener. accept ( ) . await . unwrap ( ) ;
240
- let work_tx = work_tx. clone ( ) ;
241
- let timer = timer. clone ( ) ;
242
-
243
- tokio:: task:: spawn ( async move {
244
- handle_connection ( stream, work_tx, config. recv_timeout , timer) . await ;
245
- } ) ;
280
+ let graceful_shutdown = GracefulShutdown :: new ( ) ;
281
+ let mut shutdown_rx = shutdown_rx;
282
+
283
+ loop {
284
+ tokio:: select! {
285
+ Ok ( ( stream, _) ) = listener. accept( ) => {
286
+ info!( "New connection established" ) ;
287
+
288
+ let io = TokioIo :: new( stream) ;
289
+
290
+ debug!( "Setting up connection" ) ;
291
+
292
+ let builder = builder. clone( ) ;
293
+ let work_tx = work_tx. clone( ) ;
294
+ let conn = builder. serve_connection( io, service_fn( move |req: HyperRequest <Incoming >| {
295
+ debug!( "Service handling request" ) ;
296
+ handle_request( req, work_tx. clone( ) , config. recv_timeout)
297
+ } ) ) ;
298
+ let fut = graceful_shutdown. watch( conn. into_owned( ) ) ;
299
+ tokio:: task:: spawn( async move {
300
+ if let Err ( err) = fut. await {
301
+ warn!( "Error serving connection: {:?}" , err) ;
302
+ }
303
+ } ) ;
304
+ } ,
305
+ _ = shutdown_rx. recv( ) => {
306
+ debug!( "Graceful shutdown requested; shutting down" ) ;
307
+ break ;
308
+ }
246
309
}
247
- } else {
248
- let addr: SocketAddr = config. bind_address . parse ( )
249
- . expect ( "invalid address format" ) ;
250
- let listener = tokio:: net:: TcpListener :: bind ( addr) . await . unwrap ( ) ;
251
-
252
- loop {
253
- let ( stream, _) = listener. accept ( ) . await . unwrap ( ) ;
254
- let work_tx = work_tx. clone ( ) ;
255
- let timer = timer. clone ( ) ;
256
-
257
- tokio:: task:: spawn ( async move {
258
- handle_connection ( stream, work_tx, config. recv_timeout , timer) . await ;
259
- } ) ;
310
+ }
311
+
312
+ tokio:: select! {
313
+ _ = graceful_shutdown. shutdown( ) => {
314
+ debug!( "all connections gracefully closed" ) ;
315
+ } ,
316
+ _ = tokio:: time:: sleep( std:: time:: Duration :: from_secs( 10 ) ) => {
317
+ error!( "timed out wait for all connections to close" ) ;
260
318
}
261
319
}
262
320
} ) ;
263
321
264
322
let mut handle = self . server_handle . lock ( ) . await ;
265
323
* handle = Some ( server_task) ;
266
-
324
+
267
325
Ok :: < ( ) , MagnusError > ( ( ) )
268
326
} ) ?;
269
-
327
+
270
328
Ok ( ( ) )
271
329
}
272
330
273
331
pub fn stop ( & self ) -> Result < ( ) , MagnusError > {
274
- // Use the stored runtime instead of creating a new one
275
332
if let Some ( rt) = self . runtime . borrow ( ) . as_ref ( ) {
333
+ if let Some ( shutdown) = self . shutdown . borrow ( ) . as_ref ( ) {
334
+ let _ = shutdown. send ( ( ) ) ;
335
+ }
336
+
276
337
rt. block_on ( async {
277
338
let mut handle = self . server_handle . lock ( ) . await ;
278
339
if let Some ( task) = handle. take ( ) {
279
- task. abort ( ) ;
340
+ task. await . unwrap_or_else ( |e| warn ! ( "Server task failed: {:?}" , e ) ) ;
280
341
}
281
342
} ) ;
282
343
}
283
344
284
345
// Drop the channel and runtime
285
346
self . work_tx . borrow_mut ( ) . take ( ) ;
286
347
self . runtime . borrow_mut ( ) . take ( ) ;
348
+ self . shutdown . borrow_mut ( ) . take ( ) ;
287
349
288
350
let bind_address = self . config . borrow ( ) . bind_address . clone ( ) ;
289
351
if bind_address. starts_with ( "unix:" ) {
@@ -371,41 +433,6 @@ fn create_timeout_response() -> HyperResponse<BodyWithTrailers> {
371
433
. unwrap ( )
372
434
}
373
435
374
- async fn handle_connection (
375
- stream : impl tokio:: io:: AsyncRead + tokio:: io:: AsyncWrite + Unpin + Send + ' static ,
376
- work_tx : Arc < crossbeam_channel:: Sender < RequestWithCompletion > > ,
377
- recv_timeout : u64 ,
378
- timer : hyper_util:: rt:: TokioTimer ,
379
- ) {
380
- info ! ( "New connection established" ) ;
381
-
382
- let service = service_fn ( move |req : HyperRequest < Incoming > | {
383
- debug ! ( "Service handling request" ) ;
384
- let work_tx = work_tx. clone ( ) ;
385
- handle_request ( req, work_tx, recv_timeout)
386
- } ) ;
387
-
388
- let io = TokioIo :: new ( stream) ;
389
-
390
- debug ! ( "Setting up connection" ) ;
391
- let mut builder = auto:: Builder :: new ( hyper_util:: rt:: TokioExecutor :: new ( ) ) ;
392
-
393
- builder. http1 ( )
394
- . header_read_timeout ( std:: time:: Duration :: from_millis ( recv_timeout) )
395
- . timer ( timer. clone ( ) ) ;
396
-
397
- builder. http2 ( )
398
- . keep_alive_interval ( std:: time:: Duration :: from_secs ( 10 ) )
399
- . timer ( timer) ;
400
-
401
- if let Err ( err) = builder
402
- . serve_connection ( io, service)
403
- . await
404
- {
405
- warn ! ( "Error serving connection: {:?}" , err) ;
406
- }
407
- }
408
-
409
436
// Helper function to create error responses
410
437
fn create_error_response ( error_message : & str ) -> HyperResponse < BodyWithTrailers > {
411
438
// For non-gRPC requests, return a plain HTTP error
0 commit comments