Skip to content

Commit

Permalink
Make ServeDir infallible (#283)
Browse files Browse the repository at this point in the history
* Make `ServeDir` infallible

* add missing docs

* depend on tracing for `fs`

* add note about requiring infallible fallbacks

* clean up

* format
  • Loading branch information
davidpdrsn authored Dec 2, 2022
1 parent ffcdec5 commit f8743bf
Show file tree
Hide file tree
Showing 7 changed files with 193 additions and 39 deletions.
5 changes: 4 additions & 1 deletion tower-http/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

## Changed

- None.
- **fs:** `ServeDir` and `ServeFile`'s error types are now `Infallible` and any IO errors
will be converted into responses. Use `try_call` to generate error responses manually
- **fs:** `ServeDir::fallback` and `ServeDir::not_found_service` now requires
the fallback service to use `Infallible` as its error type

## Removed

Expand Down
2 changes: 1 addition & 1 deletion tower-http/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ auth = ["base64"]
catch-panic = ["tracing", "futures-util/std"]
cors = []
follow-redirect = ["iri-string", "tower/util"]
fs = ["tokio/fs", "tokio-util/io", "tokio/io-util", "mime_guess", "mime", "percent-encoding", "httpdate", "set-status", "futures-util/alloc"]
fs = ["tokio/fs", "tokio-util/io", "tokio/io-util", "mime_guess", "mime", "percent-encoding", "httpdate", "set-status", "futures-util/alloc", "tracing"]
limit = []
map-request-body = []
map-response-body = []
Expand Down
38 changes: 38 additions & 0 deletions tower-http/src/macros.rs
Original file line number Diff line number Diff line change
Expand Up @@ -73,3 +73,41 @@ macro_rules! opaque_body {
}
};
}

#[allow(unused_macros)]
macro_rules! opaque_future {
($(#[$m:meta])* pub type $name:ident<$($param:ident),+> = $actual:ty;) => {
pin_project_lite::pin_project! {
$(#[$m])*
pub struct $name<$($param),+> {
#[pin]
inner: $actual
}
}

impl<$($param),+> $name<$($param),+> {
pub(crate) fn new(inner: $actual) -> Self {
Self {
inner
}
}
}

impl<$($param),+> std::fmt::Debug for $name<$($param),+> {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_tuple(stringify!($name)).field(&format_args!("...")).finish()
}
}

impl<$($param),+> std::future::Future for $name<$($param),+>
where
$actual: std::future::Future,
{
type Output = <$actual as std::future::Future>::Output;
#[inline]
fn poll(self: std::pin::Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> std::task::Poll<Self::Output> {
self.project().inner.poll(cx)
}
}
}
}
14 changes: 6 additions & 8 deletions tower-http/src/services/fs/serve_dir/future.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ use http::{
use http_body::{Body, Empty, Full};
use pin_project_lite::pin_project;
use std::{
convert::Infallible,
future::Future,
io,
pin::Pin,
Expand All @@ -23,7 +24,7 @@ use std::{
use tower_service::Service;

pin_project! {
/// Response future of [`ServeDir`].
/// Response future of [`ServeDir::try_call`].
pub struct ResponseFuture<ReqBody, F = DefaultServeDirFallback> {
#[pin]
pub(super) inner: ResponseFutureInner<ReqBody, F>,
Expand Down Expand Up @@ -67,7 +68,7 @@ pin_project! {
fallback_and_request: Option<(F, Request<ReqBody>)>,
},
FallbackFuture {
future: BoxFuture<'static, io::Result<Response<ResponseBody>>>,
future: BoxFuture<'static, Result<Response<ResponseBody>, Infallible>>,
},
InvalidPath {
fallback_and_request: Option<(F, Request<ReqBody>)>,
Expand All @@ -78,8 +79,7 @@ pin_project! {

impl<F, ReqBody, ResBody> Future for ResponseFuture<ReqBody, F>
where
F: Service<Request<ReqBody>, Response = Response<ResBody>> + Clone,
F::Error: Into<io::Error>,
F: Service<Request<ReqBody>, Response = Response<ResBody>, Error = Infallible> + Clone,
F::Future: Send + 'static,
ResBody: http_body::Body<Data = Bytes> + Send + 'static,
ResBody::Error: Into<Box<dyn std::error::Error + Send + Sync>>,
Expand Down Expand Up @@ -139,7 +139,7 @@ where
},

ResponseFutureInnerProj::FallbackFuture { future } => {
break Pin::new(future).poll(cx)
break Pin::new(future).poll(cx).map_err(|err| match err {})
}

ResponseFutureInnerProj::InvalidPath {
Expand Down Expand Up @@ -181,15 +181,13 @@ pub(super) fn call_fallback<F, B, FResBody>(
req: Request<B>,
) -> ResponseFutureInner<B, F>
where
F: Service<Request<B>, Response = Response<FResBody>> + Clone,
F::Error: Into<io::Error>,
F: Service<Request<B>, Response = Response<FResBody>, Error = Infallible> + Clone,
F::Future: Send + 'static,
FResBody: http_body::Body<Data = Bytes> + Send + 'static,
FResBody::Error: Into<BoxError>,
{
let future = fallback
.call(req)
.err_into()
.map_ok(|response| {
response
.map(|body| {
Expand Down
153 changes: 127 additions & 26 deletions tower-http/src/services/fs/serve_dir/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,9 @@ use crate::{
set_status::SetStatus,
};
use bytes::Bytes;
use futures_util::FutureExt;
use http::{header, HeaderValue, Method, Request, Response, StatusCode};
use http_body::{combinators::UnsyncBoxBody, Empty};
use http_body::{combinators::UnsyncBoxBody, Body, Empty};
use percent_encoding::percent_decode;
use std::{
convert::Infallible,
Expand Down Expand Up @@ -254,30 +255,80 @@ impl<F> ServeDir<F> {
self.call_fallback_on_method_not_allowed = call_fallback;
self
}
}

impl<ReqBody, F, FResBody> Service<Request<ReqBody>> for ServeDir<F>
where
F: Service<Request<ReqBody>, Response = Response<FResBody>> + Clone,
F::Error: Into<io::Error>,
F::Future: Send + 'static,
FResBody: http_body::Body<Data = Bytes> + Send + 'static,
FResBody::Error: Into<Box<dyn std::error::Error + Send + Sync>>,
{
type Response = Response<ResponseBody>;
type Error = io::Error;
type Future = ResponseFuture<ReqBody, F>;

#[inline]
fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
if let Some(fallback) = &mut self.fallback {
fallback.poll_ready(cx).map_err(Into::into)
} else {
Poll::Ready(Ok(()))
}
}

fn call(&mut self, req: Request<ReqBody>) -> Self::Future {
/// Call the service and get a future that contains any `std::io::Error` that might have
/// happened.
///
/// By default `<ServeDir as Service<_>>::call` will handle IO errors and convert them into
/// responses. It does that by converting [`std::io::ErrorKind::NotFound`] and
/// [`std::io::ErrorKind::PermissionDenied`] to `404 Not Found` and any other error to `500
/// Internal Server Error`. The error will also be logged with `tracing`.
///
/// If you want to manually control how the error response is generated you can make a new
/// service that wraps a `ServeDir` and calls `try_call` instead of `call`.
///
/// # Example
///
/// ```
/// use tower_http::services::ServeDir;
/// use std::{io, convert::Infallible};
/// use http::{Request, Response, StatusCode};
/// use http_body::{combinators::UnsyncBoxBody, Body as _};
/// use hyper::Body;
/// use bytes::Bytes;
/// use tower::{service_fn, ServiceExt, BoxError};
///
/// async fn serve_dir(
/// request: Request<Body>
/// ) -> Result<Response<UnsyncBoxBody<Bytes, BoxError>>, Infallible> {
/// let mut service = ServeDir::new("assets");
///
/// // You only need to worry about backpressure, and thus call `ServiceExt::ready`, if
/// // your adding a fallback to `ServeDir` that cares about backpressure.
/// //
/// // Its shown here for demonstration but you can do `service.try_call(request)`
/// // otherwise
/// let ready_service = match ServiceExt::<Request<Body>>::ready(&mut service).await {
/// Ok(ready_service) => ready_service,
/// Err(infallible) => match infallible {},
/// };
///
/// match ready_service.try_call(request).await {
/// Ok(response) => {
/// Ok(response.map(|body| body.map_err(Into::into).boxed_unsync()))
/// }
/// Err(err) => {
/// let body = Body::from("Something went wrong...")
/// .map_err(Into::into)
/// .boxed_unsync();
/// let response = Response::builder()
/// .status(StatusCode::INTERNAL_SERVER_ERROR)
/// .body(body)
/// .unwrap();
/// Ok(response)
/// }
/// }
/// }
///
/// # async {
/// // Run our service using `hyper`
/// let addr = std::net::SocketAddr::from(([127, 0, 0, 1], 3000));
/// hyper::Server::bind(&addr)
/// .serve(tower::make::Shared::new(service_fn(serve_dir)))
/// .await
/// .expect("server error");
/// # };
/// ```
pub fn try_call<ReqBody, FResBody>(
&mut self,
req: Request<ReqBody>,
) -> ResponseFuture<ReqBody, F>
where
F: Service<Request<ReqBody>, Response = Response<FResBody>, Error = Infallible> + Clone,
F::Future: Send + 'static,
FResBody: http_body::Body<Data = Bytes> + Send + 'static,
FResBody::Error: Into<Box<dyn std::error::Error + Send + Sync>>,
{
if req.method() != Method::GET && req.method() != Method::HEAD {
if self.call_fallback_on_method_not_allowed {
if let Some(fallback) = &mut self.fallback {
Expand Down Expand Up @@ -350,6 +401,56 @@ where
}
}

impl<ReqBody, F, FResBody> Service<Request<ReqBody>> for ServeDir<F>
where
F: Service<Request<ReqBody>, Response = Response<FResBody>, Error = Infallible> + Clone,
F::Future: Send + 'static,
FResBody: http_body::Body<Data = Bytes> + Send + 'static,
FResBody::Error: Into<Box<dyn std::error::Error + Send + Sync>>,
{
type Response = Response<ResponseBody>;
type Error = Infallible;
type Future = InfallibleResponseFuture<ReqBody, F>;

#[inline]
fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
if let Some(fallback) = &mut self.fallback {
fallback.poll_ready(cx)
} else {
Poll::Ready(Ok(()))
}
}

fn call(&mut self, req: Request<ReqBody>) -> Self::Future {
let future = self
.try_call(req)
.map(|result: Result<_, _>| -> Result<_, Infallible> {
let response = result.unwrap_or_else(|err| {
tracing::error!(error = %err, "Failed to read file");

let body =
ResponseBody::new(Empty::new().map_err(|err| match err {}).boxed_unsync());
Response::builder()
.status(StatusCode::INTERNAL_SERVER_ERROR)
.body(body)
.unwrap()
});
Ok(response)
} as _);

InfallibleResponseFuture::new(future)
}
}

opaque_future! {
/// Response future of [`ServeDir`].
pub type InfallibleResponseFuture<ReqBody, F> =
futures_util::future::Map<
ResponseFuture<ReqBody, F>,
fn(Result<Response<ResponseBody>, io::Error>) -> Result<Response<ResponseBody>, Infallible>,
>;
}

// Allow the ServeDir service to be used in the ServeFile service
// with almost no overhead
#[derive(Clone, Debug)]
Expand Down Expand Up @@ -414,8 +515,8 @@ where
ReqBody: Send + 'static,
{
type Response = Response<ResponseBody>;
type Error = io::Error;
type Future = ResponseFuture<ReqBody>;
type Error = Infallible;
type Future = InfallibleResponseFuture<ReqBody, Self>;

fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
match self.0 {}
Expand Down
6 changes: 3 additions & 3 deletions tower-http/src/services/fs/serve_dir/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -587,7 +587,7 @@ async fn last_modified() {

#[tokio::test]
async fn with_fallback_svc() {
async fn fallback<B>(req: Request<B>) -> io::Result<Response<Body>> {
async fn fallback<B>(req: Request<B>) -> Result<Response<Body>, Infallible> {
Ok(Response::new(Body::from(format!(
"from fallback {}",
req.uri().path()
Expand Down Expand Up @@ -644,7 +644,7 @@ async fn method_not_allowed() {

#[tokio::test]
async fn calling_fallback_on_not_allowed() {
async fn fallback<B>(req: Request<B>) -> io::Result<Response<Body>> {
async fn fallback<B>(req: Request<B>) -> Result<Response<Body>, Infallible> {
Ok(Response::new(Body::from(format!(
"from fallback {}",
req.uri().path()
Expand All @@ -670,7 +670,7 @@ async fn calling_fallback_on_not_allowed() {

#[tokio::test]
async fn with_fallback_svc_and_not_append_index_html_on_directories() {
async fn fallback<B>(req: Request<B>) -> io::Result<Response<Body>> {
async fn fallback<B>(req: Request<B>) -> Result<Response<Body>, Infallible> {
Ok(Response::new(Body::from(format!(
"from fallback {}",
req.uri().path()
Expand Down
14 changes: 14 additions & 0 deletions tower-http/src/services/fs/serve_file.rs
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,20 @@ impl ServeFile {
pub fn with_buf_chunk_size(self, chunk_size: usize) -> Self {
Self(self.0.with_buf_chunk_size(chunk_size))
}

/// Call the service and get a future that contains any `std::io::Error` that might have
/// happened.
///
/// See [`ServeDir::try_call`] for more details.
pub fn try_call<ReqBody>(
&mut self,
req: Request<ReqBody>,
) -> super::serve_dir::future::ResponseFuture<ReqBody>
where
ReqBody: Send + 'static,
{
self.0.try_call(req)
}
}

impl<ReqBody> Service<Request<ReqBody>> for ServeFile
Expand Down

0 comments on commit f8743bf

Please sign in to comment.