Skip to content

Commit ff697f2

Browse files
Add graceful shutdown of connections.
1 parent 3db4b31 commit ff697f2

File tree

2 files changed

+98
-71
lines changed

2 files changed

+98
-71
lines changed

ext/hyper_ruby/Cargo.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ tokio-stream = { version = "0.1", features = ["net"] }
1717
crossbeam-channel = "0.5.14"
1818
rb-sys = "0.9.110"
1919
hyper = { version = "1.0", features = ["http1", "http2", "server"] }
20-
hyper-util = { version = "0.1", features = ["tokio", "server", "server-auto", "http1", "http2"] }
20+
hyper-util = { version = "0.1", features = ["tokio", "server", "server-graceful", "server-auto", "http1", "http2"] }
2121
http-body-util = "0.1.2"
2222
jemallocator = { version = "0.5.4", features = ["disable_initial_exec_tls"] }
2323
futures = "0.3.31"

ext/hyper_ruby/src/lib.rs

Lines changed: 97 additions & 70 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ mod response;
33
mod gvl_helpers;
44
mod grpc;
55

6+
use hyper_util::server::graceful::GracefulShutdown;
67
use request::{Request, GrpcRequest};
78
use response::{Response, GrpcResponse};
89
use gvl_helpers::nogvl;
@@ -11,11 +12,12 @@ use magnus::block::block_proc;
1112
use magnus::typed_data::Obj;
1213
use magnus::{function, method, prelude::*, Error as MagnusError, IntoValue, Ruby, Value, RString};
1314
use bytes::Bytes;
15+
use tokio::io::{AsyncRead, AsyncWrite};
1416

1517
use std::cell::RefCell;
1618
use std::net::SocketAddr;
1719

18-
use tokio::net::UnixListener;
20+
use tokio::net::{TcpListener, UnixListener};
1921

2022
use std::sync::Arc;
2123
use tokio::sync::{Mutex, oneshot};
@@ -31,18 +33,45 @@ use http_body_util::BodyExt;
3133

3234
use jemallocator::Jemalloc;
3335

34-
use log::{debug, info, warn};
36+
use log::{debug, info, warn, error};
3537

3638
use env_logger;
3739
use crate::response::BodyWithTrailers;
3840
use std::sync::Once;
3941
use tokio::time::timeout;
4042

43+
use std::io;
44+
45+
use tokio::sync::broadcast;
46+
4147
static LOGGER_INIT: Once = Once::new();
4248

4349
#[global_allocator]
4450
static GLOBAL: Jemalloc = Jemalloc;
4551

52+
trait AsyncStream: AsyncRead + AsyncWrite + Unpin + Send {}
53+
impl<T: AsyncRead + AsyncWrite + Unpin + Send> AsyncStream for T {}
54+
55+
enum Listener {
56+
Unix(UnixListener),
57+
Tcp(TcpListener),
58+
}
59+
60+
impl Listener {
61+
async fn accept(&self) -> io::Result<(Box<dyn AsyncStream>, SocketAddr)> {
62+
match self {
63+
Listener::Unix(l) => {
64+
let (stream, _) = l.accept().await?;
65+
Ok((Box::new(stream), "0.0.0.0:0".parse().unwrap()))
66+
}
67+
Listener::Tcp(l) => {
68+
let (stream, addr) = l.accept().await?;
69+
Ok((Box::new(stream), addr))
70+
}
71+
}
72+
}
73+
}
74+
4675
#[derive(Clone)]
4776
struct ServerConfig {
4877
bind_address: String,
@@ -75,18 +104,19 @@ struct Server {
75104
work_rx: RefCell<Option<crossbeam_channel::Receiver<RequestWithCompletion>>>,
76105
work_tx: RefCell<Option<Arc<crossbeam_channel::Sender<RequestWithCompletion>>>>,
77106
runtime: RefCell<Option<Arc<tokio::runtime::Runtime>>>,
107+
shutdown: RefCell<Option<broadcast::Sender<()>>>,
78108
}
79109

80110
impl Server {
81111
pub fn new() -> Self {
82112
let (work_tx, work_rx) = crossbeam_channel::bounded(1000);
83-
84113
Self {
85114
server_handle: Arc::new(Mutex::new(None)),
86115
config: RefCell::new(ServerConfig::new()),
87116
work_rx: RefCell::new(Some(work_rx)),
88117
work_tx: RefCell::new(Some(Arc::new(work_tx))),
89118
runtime: RefCell::new(None),
119+
shutdown: RefCell::new(None),
90120
}
91121
}
92122

@@ -211,6 +241,9 @@ impl Server {
211241
.ok_or_else(|| MagnusError::new(magnus::exception::runtime_error(), "Work channel not initialized"))?
212242
.clone();
213243

244+
let (shutdown_tx, shutdown_rx) = broadcast::channel(1);
245+
*self.shutdown.borrow_mut() = Some(shutdown_tx);
246+
214247
let mut rt_builder = tokio::runtime::Builder::new_multi_thread();
215248

216249
rt_builder.enable_all();
@@ -225,65 +258,94 @@ impl Server {
225258

226259
*self.runtime.borrow_mut() = Some(rt.clone());
227260

228-
rt.block_on(async {
229-
let work_tx = work_tx.clone();
230-
261+
262+
rt.block_on(async move {
231263
let server_task = tokio::spawn(async move {
232264
let timer = hyper_util::rt::TokioTimer::new();
265+
let mut builder = auto::Builder::new(hyper_util::rt::TokioExecutor::new());
266+
builder.http1()
267+
.header_read_timeout(std::time::Duration::from_millis(config.recv_timeout))
268+
.timer(timer.clone());
269+
builder.http2()
270+
.keep_alive_interval(std::time::Duration::from_secs(10))
271+
.timer(timer);
272+
273+
let listener = if config.bind_address.starts_with("unix:") {
274+
Listener::Unix(UnixListener::bind(config.bind_address.trim_start_matches("unix:")).unwrap())
275+
} else {
276+
let addr: SocketAddr = config.bind_address.parse().expect("invalid address format");
277+
Listener::Tcp(TcpListener::bind(addr).await.unwrap())
278+
};
233279

234-
if config.bind_address.starts_with("unix:") {
235-
let path = config.bind_address.trim_start_matches("unix:");
236-
let listener = UnixListener::bind(path).unwrap();
237-
238-
loop {
239-
let (stream, _) = listener.accept().await.unwrap();
240-
let work_tx = work_tx.clone();
241-
let timer = timer.clone();
242-
243-
tokio::task::spawn(async move {
244-
handle_connection(stream, work_tx, config.recv_timeout, timer).await;
245-
});
280+
let graceful_shutdown = GracefulShutdown::new();
281+
let mut shutdown_rx = shutdown_rx;
282+
283+
loop {
284+
tokio::select! {
285+
Ok((stream, _)) = listener.accept() => {
286+
info!("New connection established");
287+
288+
let io = TokioIo::new(stream);
289+
290+
debug!("Setting up connection");
291+
292+
let builder = builder.clone();
293+
let work_tx = work_tx.clone();
294+
let conn = builder.serve_connection(io, service_fn(move |req: HyperRequest<Incoming>| {
295+
debug!("Service handling request");
296+
handle_request(req, work_tx.clone(), config.recv_timeout)
297+
}));
298+
let fut = graceful_shutdown.watch(conn.into_owned());
299+
tokio::task::spawn(async move {
300+
if let Err(err) = fut.await {
301+
warn!("Error serving connection: {:?}", err);
302+
}
303+
});
304+
},
305+
_ = shutdown_rx.recv() => {
306+
debug!("Graceful shutdown requested; shutting down");
307+
break;
308+
}
246309
}
247-
} else {
248-
let addr: SocketAddr = config.bind_address.parse()
249-
.expect("invalid address format");
250-
let listener = tokio::net::TcpListener::bind(addr).await.unwrap();
251-
252-
loop {
253-
let (stream, _) = listener.accept().await.unwrap();
254-
let work_tx = work_tx.clone();
255-
let timer = timer.clone();
256-
257-
tokio::task::spawn(async move {
258-
handle_connection(stream, work_tx, config.recv_timeout, timer).await;
259-
});
310+
}
311+
312+
tokio::select! {
313+
_ = graceful_shutdown.shutdown() => {
314+
debug!("all connections gracefully closed");
315+
},
316+
_ = tokio::time::sleep(std::time::Duration::from_secs(10)) => {
317+
error!("timed out wait for all connections to close");
260318
}
261319
}
262320
});
263321

264322
let mut handle = self.server_handle.lock().await;
265323
*handle = Some(server_task);
266-
324+
267325
Ok::<(), MagnusError>(())
268326
})?;
269-
327+
270328
Ok(())
271329
}
272330

273331
pub fn stop(&self) -> Result<(), MagnusError> {
274-
// Use the stored runtime instead of creating a new one
275332
if let Some(rt) = self.runtime.borrow().as_ref() {
333+
if let Some(shutdown) = self.shutdown.borrow().as_ref() {
334+
let _ = shutdown.send(());
335+
}
336+
276337
rt.block_on(async {
277338
let mut handle = self.server_handle.lock().await;
278339
if let Some(task) = handle.take() {
279-
task.abort();
340+
task.await.unwrap_or_else(|e| warn!("Server task failed: {:?}", e));
280341
}
281342
});
282343
}
283344

284345
// Drop the channel and runtime
285346
self.work_tx.borrow_mut().take();
286347
self.runtime.borrow_mut().take();
348+
self.shutdown.borrow_mut().take();
287349

288350
let bind_address = self.config.borrow().bind_address.clone();
289351
if bind_address.starts_with("unix:") {
@@ -371,41 +433,6 @@ fn create_timeout_response() -> HyperResponse<BodyWithTrailers> {
371433
.unwrap()
372434
}
373435

374-
async fn handle_connection(
375-
stream: impl tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin + Send + 'static,
376-
work_tx: Arc<crossbeam_channel::Sender<RequestWithCompletion>>,
377-
recv_timeout: u64,
378-
timer: hyper_util::rt::TokioTimer,
379-
) {
380-
info!("New connection established");
381-
382-
let service = service_fn(move |req: HyperRequest<Incoming>| {
383-
debug!("Service handling request");
384-
let work_tx = work_tx.clone();
385-
handle_request(req, work_tx, recv_timeout)
386-
});
387-
388-
let io = TokioIo::new(stream);
389-
390-
debug!("Setting up connection");
391-
let mut builder = auto::Builder::new(hyper_util::rt::TokioExecutor::new());
392-
393-
builder.http1()
394-
.header_read_timeout(std::time::Duration::from_millis(recv_timeout))
395-
.timer(timer.clone());
396-
397-
builder.http2()
398-
.keep_alive_interval(std::time::Duration::from_secs(10))
399-
.timer(timer);
400-
401-
if let Err(err) = builder
402-
.serve_connection(io, service)
403-
.await
404-
{
405-
warn!("Error serving connection: {:?}", err);
406-
}
407-
}
408-
409436
// Helper function to create error responses
410437
fn create_error_response(error_message: &str) -> HyperResponse<BodyWithTrailers> {
411438
// For non-gRPC requests, return a plain HTTP error

0 commit comments

Comments
 (0)