Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make ServeDir infallible #283

Merged
merged 7 commits into from
Dec 2, 2022
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Next Next commit
Make ServeDir infallible
  • Loading branch information
davidpdrsn committed Jul 12, 2022
commit f3af43f222e570d05c9780c1b9549f5ef233af73
5 changes: 4 additions & 1 deletion tower-http/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,13 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
## Added

- Add `NormalizePath` middleware
- **fs:** Add `ServeDir::try_call` and `ServeFile::try_call` to handle how IO
errors are converted to responses

## Changed

- None.
- **fs:** `ServeDir` and `ServeFile`'s error types are now `Infallible` and any IO errors
will be converted into responses. Use `try_call` to generate error responses manually

## Removed

Expand Down
38 changes: 38 additions & 0 deletions tower-http/src/macros.rs
Original file line number Diff line number Diff line change
Expand Up @@ -73,3 +73,41 @@ macro_rules! opaque_body {
}
};
}

#[allow(unused_macros)]
macro_rules! opaque_future {
($(#[$m:meta])* pub type $name:ident<$($param:ident),+> = $actual:ty;) => {
pin_project_lite::pin_project! {
$(#[$m])*
pub struct $name<$($param),+> {
#[pin]
inner: $actual
}
}

impl<$($param),+> $name<$($param),+> {
pub(crate) fn new(inner: $actual) -> Self {
Self {
inner
}
}
}

impl<$($param),+> std::fmt::Debug for $name<$($param),+> {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_tuple(stringify!($name)).field(&format_args!("...")).finish()
}
}

impl<$($param),+> std::future::Future for $name<$($param),+>
where
$actual: std::future::Future,
{
type Output = <$actual as std::future::Future>::Output;
#[inline]
fn poll(self: std::pin::Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> std::task::Poll<Self::Output> {
self.project().inner.poll(cx)
}
}
}
}
davidpdrsn marked this conversation as resolved.
Show resolved Hide resolved
12 changes: 5 additions & 7 deletions tower-http/src/services/fs/serve_dir/future.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ use http::{
use http_body::{Body, Empty, Full};
use pin_project_lite::pin_project;
use std::{
convert::Infallible,
future::Future,
io,
pin::Pin,
Expand Down Expand Up @@ -65,7 +66,7 @@ pin_project! {
fallback_and_request: Option<(F, Request<ReqBody>)>,
},
FallbackFuture {
future: BoxFuture<'static, io::Result<Response<ResponseBody>>>,
future: BoxFuture<'static, Result<Response<ResponseBody>, Infallible>>,
},
InvalidPath,
MethodNotAllowed,
Expand All @@ -74,8 +75,7 @@ pin_project! {

impl<F, ReqBody, ResBody> Future for ResponseFuture<ReqBody, F>
where
F: Service<Request<ReqBody>, Response = Response<ResBody>> + Clone,
F::Error: Into<io::Error>,
F: Service<Request<ReqBody>, Response = Response<ResBody>, Error = Infallible> + Clone,
F::Future: Send + 'static,
ResBody: http_body::Body<Data = Bytes> + Send + 'static,
ResBody::Error: Into<Box<dyn std::error::Error + Send + Sync>>,
Expand Down Expand Up @@ -135,7 +135,7 @@ where
},

ResponseFutureInnerProj::FallbackFuture { future } => {
break Pin::new(future).poll(cx)
break Pin::new(future).poll(cx).map_err(|err| match err {})
}

ResponseFutureInnerProj::InvalidPath => {
Expand Down Expand Up @@ -171,15 +171,13 @@ pub(super) fn call_fallback<F, B, FResBody>(
req: Request<B>,
) -> ResponseFutureInner<B, F>
where
F: Service<Request<B>, Response = Response<FResBody>> + Clone,
F::Error: Into<io::Error>,
F: Service<Request<B>, Response = Response<FResBody>, Error = Infallible> + Clone,
F::Future: Send + 'static,
FResBody: http_body::Body<Data = Bytes> + Send + 'static,
FResBody::Error: Into<BoxError>,
{
let future = fallback
.call(req)
.err_into()
.map_ok(|response| {
response
.map(|body| {
Expand Down
153 changes: 127 additions & 26 deletions tower-http/src/services/fs/serve_dir/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,9 @@ use crate::{
set_status::SetStatus,
};
use bytes::Bytes;
use futures_util::FutureExt;
use http::{header, HeaderValue, Method, Request, Response, StatusCode};
use http_body::{combinators::UnsyncBoxBody, Empty};
use http_body::{combinators::UnsyncBoxBody, Body, Empty};
use percent_encoding::percent_decode;
use std::{
convert::Infallible,
Expand Down Expand Up @@ -254,30 +255,80 @@ impl<F> ServeDir<F> {
self.call_fallback_on_method_not_allowed = call_fallback;
self
}
}

impl<ReqBody, F, FResBody> Service<Request<ReqBody>> for ServeDir<F>
where
F: Service<Request<ReqBody>, Response = Response<FResBody>> + Clone,
F::Error: Into<io::Error>,
F::Future: Send + 'static,
FResBody: http_body::Body<Data = Bytes> + Send + 'static,
FResBody::Error: Into<Box<dyn std::error::Error + Send + Sync>>,
{
type Response = Response<ResponseBody>;
type Error = io::Error;
type Future = ResponseFuture<ReqBody, F>;

#[inline]
fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
if let Some(fallback) = &mut self.fallback {
fallback.poll_ready(cx).map_err(Into::into)
} else {
Poll::Ready(Ok(()))
}
}

fn call(&mut self, req: Request<ReqBody>) -> Self::Future {
/// Call the service and get a future that contains any `std::io::Error` that might have
/// happened.
///
/// By default `<ServeDir as Service<_>>::call` will handle IO errors and convert them into
/// responses. It does that by converting [`std::io::ErrorKind::NotFound`] and
/// [`std::io::ErrorKind::PermissionDenied`] to `404 Not Found` and any other error to `500
/// Internal Server Error`. The error will also be logged with `tracing`.
///
/// If you want to manually control how the error response is generated you can make a new
/// service that wraps a `ServeDir` and calls `try_call` instead of `call`.
///
/// # Example
///
/// ```
/// use tower_http::services::ServeDir;
/// use std::{io, convert::Infallible};
/// use http::{Request, Response, StatusCode};
/// use http_body::{combinators::UnsyncBoxBody, Body as _};
/// use hyper::Body;
/// use bytes::Bytes;
/// use tower::{service_fn, ServiceExt, BoxError};
///
/// async fn serve_dir(
/// request: Request<Body>
/// ) -> Result<Response<UnsyncBoxBody<Bytes, BoxError>>, Infallible> {
/// let mut service = ServeDir::new("assets");
///
/// // You only need to worry about backpressure, and thus call `ServiceExt::ready`, if
/// // your adding a fallback to `ServeDir` that cares about backpressure.
/// //
/// // Its shown here for demonstration but you can do `service.try_call(request)`
/// // otherwise
/// let ready_service = match ServiceExt::<Request<Body>>::ready(&mut service).await {
/// Ok(ready_service) => ready_service,
/// Err(infallible) => match infallible {},
/// };
///
/// match ready_service.try_call(request).await {
/// Ok(response) => {
/// Ok(response.map(|body| body.map_err(Into::into).boxed_unsync()))
/// }
/// Err(err) => {
/// let body = Body::from("Something went wrong...")
/// .map_err(Into::into)
/// .boxed_unsync();
/// let response = Response::builder()
/// .status(StatusCode::INTERNAL_SERVER_ERROR)
/// .body(body)
/// .unwrap();
/// Ok(response)
/// }
/// }
/// }
///
/// # async {
/// // Run our service using `hyper`
/// let addr = std::net::SocketAddr::from(([127, 0, 0, 1], 3000));
/// hyper::Server::bind(&addr)
/// .serve(tower::make::Shared::new(service_fn(serve_dir)))
/// .await
/// .expect("server error");
/// # };
/// ```
pub fn try_call<ReqBody, FResBody>(
&mut self,
req: Request<ReqBody>,
) -> ResponseFuture<ReqBody, F>
where
F: Service<Request<ReqBody>, Response = Response<FResBody>, Error = Infallible> + Clone,
F::Future: Send + 'static,
FResBody: http_body::Body<Data = Bytes> + Send + 'static,
FResBody::Error: Into<Box<dyn std::error::Error + Send + Sync>>,
{
if req.method() != Method::GET && req.method() != Method::HEAD {
if self.call_fallback_on_method_not_allowed {
if let Some(fallback) = &mut self.fallback {
Expand Down Expand Up @@ -350,6 +401,56 @@ where
}
}

impl<ReqBody, F, FResBody> Service<Request<ReqBody>> for ServeDir<F>
where
F: Service<Request<ReqBody>, Response = Response<FResBody>, Error = Infallible> + Clone,
F::Future: Send + 'static,
FResBody: http_body::Body<Data = Bytes> + Send + 'static,
FResBody::Error: Into<Box<dyn std::error::Error + Send + Sync>>,
{
type Response = Response<ResponseBody>;
type Error = Infallible;
type Future = InfallibleResponseFuture<ReqBody, F>;

#[inline]
fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
if let Some(fallback) = &mut self.fallback {
fallback.poll_ready(cx)
} else {
Poll::Ready(Ok(()))
}
}

fn call(&mut self, req: Request<ReqBody>) -> Self::Future {
let future = self.try_call(req).map(|result| -> Result<_, Infallible> {
match result {
davidpdrsn marked this conversation as resolved.
Show resolved Hide resolved
Ok(response) => Ok(response),
Err(err) => {
tracing::error!(error = %err, "Failed to read file");

let body =
ResponseBody::new(Empty::new().map_err(|err| match err {}).boxed_unsync());
Nehliin marked this conversation as resolved.
Show resolved Hide resolved
let response = Response::builder()
.status(StatusCode::INTERNAL_SERVER_ERROR)
.body(body)
.unwrap();
Ok(response)
}
}
} as _);
davidpdrsn marked this conversation as resolved.
Show resolved Hide resolved

InfallibleResponseFuture::new(future)
}
}

opaque_future! {
pub type InfallibleResponseFuture<ReqBody, F> =
futures_util::future::Map<
ResponseFuture<ReqBody, F>,
fn(Result<Response<ResponseBody>, io::Error>) -> Result<Response<ResponseBody>, Infallible>,
>;
}

// Allow the ServeDir service to be used in the ServeFile service
// with almost no overhead
#[derive(Clone, Debug)]
Expand Down Expand Up @@ -414,8 +515,8 @@ where
ReqBody: Send + 'static,
{
type Response = Response<ResponseBody>;
type Error = io::Error;
type Future = ResponseFuture<ReqBody>;
type Error = Infallible;
type Future = InfallibleResponseFuture<ReqBody, Self>;

fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
match self.0 {}
Expand Down
7 changes: 4 additions & 3 deletions tower-http/src/services/fs/serve_dir/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ use http::{header, Method, Response};
use http::{Request, StatusCode};
use http_body::Body as HttpBody;
use hyper::Body;
use std::convert::Infallible;
use std::io::{self, Read};
use tower::ServiceExt;

Expand Down Expand Up @@ -586,7 +587,7 @@ async fn last_modified() {

#[tokio::test]
async fn with_fallback_svc() {
async fn fallback<B>(req: Request<B>) -> io::Result<Response<Body>> {
async fn fallback<B>(req: Request<B>) -> Result<Response<Body>, Infallible> {
Ok(Response::new(Body::from(format!(
"from fallback {}",
req.uri().path()
Expand Down Expand Up @@ -643,7 +644,7 @@ async fn method_not_allowed() {

#[tokio::test]
async fn calling_fallback_on_not_allowed() {
async fn fallback<B>(req: Request<B>) -> io::Result<Response<Body>> {
async fn fallback<B>(req: Request<B>) -> Result<Response<Body>, Infallible> {
Ok(Response::new(Body::from(format!(
"from fallback {}",
req.uri().path()
Expand All @@ -669,7 +670,7 @@ async fn calling_fallback_on_not_allowed() {

#[tokio::test]
async fn with_fallback_svc_and_not_append_index_html_on_directories() {
async fn fallback<B>(req: Request<B>) -> io::Result<Response<Body>> {
async fn fallback<B>(req: Request<B>) -> Result<Response<Body>, Infallible> {
Ok(Response::new(Body::from(format!(
"from fallback {}",
req.uri().path()
Expand Down
14 changes: 14 additions & 0 deletions tower-http/src/services/fs/serve_file.rs
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,20 @@ impl ServeFile {
pub fn with_buf_chunk_size(self, chunk_size: usize) -> Self {
Self(self.0.with_buf_chunk_size(chunk_size))
}

/// Call the service and get a future that contains any `std::io::Error` that might have
/// happened.
///
/// See [`ServeDir::try_call`] for more details.
pub fn try_call<ReqBody>(
&mut self,
req: Request<ReqBody>,
) -> super::serve_dir::future::ResponseFuture<ReqBody>
where
ReqBody: Send + 'static,
{
self.0.try_call(req)
}
}

impl<ReqBody> Service<Request<ReqBody>> for ServeFile
Expand Down