Skip to content

Commit

Permalink
fix(server): start header read timeout immediately (#3185)
Browse files Browse the repository at this point in the history
The `http1_header_read_timeout` used to start once there was a single
read of headers. This change makes it start the timer immediately, right
when the connection is estabilished.
  • Loading branch information
seanmonstar authored Jun 3, 2024
1 parent f9aa697 commit 0eb1b6c
Show file tree
Hide file tree
Showing 5 changed files with 71 additions and 159 deletions.
62 changes: 49 additions & 13 deletions src/proto/h1/conn.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
use std::fmt;
#[cfg(feature = "server")]
use std::future::Future;
use std::io;
use std::marker::{PhantomData, Unpin};
use std::pin::Pin;
use std::task::{Context, Poll};
#[cfg(feature = "server")]
use std::time::Duration;
use std::time::{Duration, Instant};

use crate::rt::{Read, Write};
use bytes::{Buf, Bytes};
Expand Down Expand Up @@ -209,33 +211,67 @@ where
debug_assert!(self.can_read_head());
trace!("Conn::read_head");

let msg = match ready!(self.io.parse::<T>(
#[cfg(feature = "server")]
if !self.state.h1_header_read_timeout_running {
if let Some(h1_header_read_timeout) = self.state.h1_header_read_timeout {
let deadline = Instant::now() + h1_header_read_timeout;
self.state.h1_header_read_timeout_running = true;
match self.state.h1_header_read_timeout_fut {
Some(ref mut h1_header_read_timeout_fut) => {
trace!("resetting h1 header read timeout timer");
self.state.timer.reset(h1_header_read_timeout_fut, deadline);
}
None => {
trace!("setting h1 header read timeout timer");
self.state.h1_header_read_timeout_fut =
Some(self.state.timer.sleep_until(deadline));
}
}
}
}

let msg = match self.io.parse::<T>(
cx,
ParseContext {
cached_headers: &mut self.state.cached_headers,
req_method: &mut self.state.method,
h1_parser_config: self.state.h1_parser_config.clone(),
h1_max_headers: self.state.h1_max_headers,
#[cfg(feature = "server")]
h1_header_read_timeout: self.state.h1_header_read_timeout,
#[cfg(feature = "server")]
h1_header_read_timeout_fut: &mut self.state.h1_header_read_timeout_fut,
#[cfg(feature = "server")]
h1_header_read_timeout_running: &mut self.state.h1_header_read_timeout_running,
#[cfg(feature = "server")]
timer: self.state.timer.clone(),
preserve_header_case: self.state.preserve_header_case,
#[cfg(feature = "ffi")]
preserve_header_order: self.state.preserve_header_order,
h09_responses: self.state.h09_responses,
#[cfg(feature = "ffi")]
on_informational: &mut self.state.on_informational,
},
) {
Poll::Ready(Ok(msg)) => msg,
Poll::Ready(Err(e)) => return self.on_read_head_error(e),
Poll::Pending => {
#[cfg(feature = "server")]
if self.state.h1_header_read_timeout_running {
if let Some(ref mut h1_header_read_timeout_fut) =
self.state.h1_header_read_timeout_fut
{
if Pin::new(h1_header_read_timeout_fut).poll(cx).is_ready() {
self.state.h1_header_read_timeout_running = false;

warn!("read header from client timeout");
return Poll::Ready(Some(Err(crate::Error::new_header_timeout())));
}
}
}

return Poll::Pending;
}
)) {
Ok(msg) => msg,
Err(e) => return self.on_read_head_error(e),
};

#[cfg(feature = "server")]
{
self.state.h1_header_read_timeout_running = false;
self.state.h1_header_read_timeout_fut = None;
}

// Note: don't deconstruct `msg` into local variables, it appears
// the optimizer doesn't remove the extra copies.

Expand Down
38 changes: 1 addition & 37 deletions src/proto/h1/io.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
use std::cmp;
use std::fmt;
#[cfg(feature = "server")]
use std::future::Future;
use std::io::{self, IoSlice};
use std::pin::Pin;
use std::task::{Context, Poll};
Expand Down Expand Up @@ -183,14 +181,6 @@ where
req_method: parse_ctx.req_method,
h1_parser_config: parse_ctx.h1_parser_config.clone(),
h1_max_headers: parse_ctx.h1_max_headers,
#[cfg(feature = "server")]
h1_header_read_timeout: parse_ctx.h1_header_read_timeout,
#[cfg(feature = "server")]
h1_header_read_timeout_fut: parse_ctx.h1_header_read_timeout_fut,
#[cfg(feature = "server")]
h1_header_read_timeout_running: parse_ctx.h1_header_read_timeout_running,
#[cfg(feature = "server")]
timer: parse_ctx.timer.clone(),
preserve_header_case: parse_ctx.preserve_header_case,
#[cfg(feature = "ffi")]
preserve_header_order: parse_ctx.preserve_header_order,
Expand All @@ -201,12 +191,6 @@ where
)? {
Some(msg) => {
debug!("parsed {} headers", msg.head.headers.len());

#[cfg(feature = "server")]
{
*parse_ctx.h1_header_read_timeout_running = false;
parse_ctx.h1_header_read_timeout_fut.take();
}
return Poll::Ready(Ok(msg));
}
None => {
Expand All @@ -215,20 +199,6 @@ where
debug!("max_buf_size ({}) reached, closing", max);
return Poll::Ready(Err(crate::Error::new_too_large()));
}

#[cfg(feature = "server")]
if *parse_ctx.h1_header_read_timeout_running {
if let Some(h1_header_read_timeout_fut) =
parse_ctx.h1_header_read_timeout_fut
{
if Pin::new(h1_header_read_timeout_fut).poll(cx).is_ready() {
*parse_ctx.h1_header_read_timeout_running = false;

warn!("read header from client timeout");
return Poll::Ready(Err(crate::Error::new_header_timeout()));
}
}
}
}
}
if ready!(self.poll_read_from_io(cx)).map_err(crate::Error::new_io)? == 0 {
Expand Down Expand Up @@ -660,10 +630,8 @@ enum WriteStrategy {

#[cfg(test)]
mod tests {
use crate::common::io::Compat;
use crate::common::time::Time;

use super::*;
use crate::common::io::Compat;
use std::time::Duration;

use tokio_test::io::Builder as Mock;
Expand Down Expand Up @@ -726,10 +694,6 @@ mod tests {
req_method: &mut None,
h1_parser_config: Default::default(),
h1_max_headers: None,
h1_header_read_timeout: None,
h1_header_read_timeout_fut: &mut None,
h1_header_read_timeout_running: &mut false,
timer: Time::Empty,
preserve_header_case: false,
#[cfg(feature = "ffi")]
preserve_header_order: false,
Expand Down
15 changes: 0 additions & 15 deletions src/proto/h1/mod.rs
Original file line number Diff line number Diff line change
@@ -1,16 +1,9 @@
#[cfg(feature = "server")]
use std::{pin::Pin, time::Duration};

use bytes::BytesMut;
use http::{HeaderMap, Method};
use httparse::ParserConfig;

use crate::body::DecodedLength;
#[cfg(feature = "server")]
use crate::common::time::Time;
use crate::proto::{BodyLength, MessageHead};
#[cfg(feature = "server")]
use crate::rt::Sleep;

pub(crate) use self::conn::Conn;
pub(crate) use self::decode::Decoder;
Expand Down Expand Up @@ -79,14 +72,6 @@ pub(crate) struct ParseContext<'a> {
req_method: &'a mut Option<Method>,
h1_parser_config: ParserConfig,
h1_max_headers: Option<usize>,
#[cfg(feature = "server")]
h1_header_read_timeout: Option<Duration>,
#[cfg(feature = "server")]
h1_header_read_timeout_fut: &'a mut Option<Pin<Box<dyn Sleep>>>,
#[cfg(feature = "server")]
h1_header_read_timeout_running: &'a mut bool,
#[cfg(feature = "server")]
timer: Time,
preserve_header_case: bool,
#[cfg(feature = "ffi")]
preserve_header_order: bool,
Expand Down
Loading

0 comments on commit 0eb1b6c

Please sign in to comment.