Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

src: move State off of Request (WIP) #645

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions src/cookies/middleware.rs
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ impl CookiesMiddleware {
impl<State: Send + Sync + 'static> Middleware<State> for CookiesMiddleware {
fn handle<'a>(
&'a self,
mut ctx: Request<State>,
mut ctx: Request,
next: Next<'a, State>,
) -> BoxFuture<'a, crate::Result> {
Box::pin(async move {
Expand Down Expand Up @@ -117,7 +117,7 @@ impl LazyJar {
}

impl CookieData {
pub(crate) fn from_request<S>(req: &Request<S>) -> Self {
pub(crate) fn from_request(req: &Request) -> Self {
let jar = if let Some(cookie_headers) = req.header(&headers::COOKIE) {
let mut jar = CookieJar::new();
for cookie_header in cookie_headers {
Expand Down
24 changes: 20 additions & 4 deletions src/endpoint.rs
Original file line number Diff line number Diff line change
Expand Up @@ -46,19 +46,19 @@ use crate::{Middleware, Request, Response};
/// Tide routes will also accept endpoints with `Fn` signatures of this form, but using the `async` keyword has better ergonomics.
pub trait Endpoint<State: Send + Sync + 'static>: Send + Sync + 'static {
/// Invoke the endpoint within the given context
fn call<'a>(&'a self, req: Request<State>) -> BoxFuture<'a, crate::Result>;
fn call<'a>(&'a self, req: Request, state: State) -> BoxFuture<'a, crate::Result>;
}

pub(crate) type DynEndpoint<State> = dyn Endpoint<State>;

impl<State, F, Fut, Res> Endpoint<State> for F
where
State: Send + Sync + 'static,
F: Send + Sync + 'static + Fn(Request<State>) -> Fut,
F: Send + Sync + 'static + Fn(Request) -> Fut,
Fut: Future<Output = Result<Res>> + Send + 'static,
Res: Into<Response>,
{
fn call<'a>(&'a self, req: Request<State>) -> BoxFuture<'a, crate::Result> {
fn call<'a>(&'a self, req: Request, _: State) -> BoxFuture<'a, crate::Result> {
let fut = (self)(req);
Box::pin(async move {
let res = fut.await?;
Expand All @@ -67,6 +67,22 @@ where
}
}

impl<State, F, Fut, Res> Endpoint<State> for F
where
State: Send + Sync + 'static,
F: Send + Sync + 'static + Fn(Request, State) -> Fut,
Fut: Future<Output = Result<Res>> + Send + 'static,
Res: Into<Response>,
{
fn call<'a>(&'a self, req: Request, state: State) -> BoxFuture<'a, crate::Result> {
let fut = (self)(req, state);
Box::pin(async move {
let res = fut.await?;
Ok(res.into())
})
}
}

pub struct MiddlewareEndpoint<E, State> {
endpoint: E,
middleware: Vec<Arc<dyn Middleware<State>>>,
Expand Down Expand Up @@ -109,7 +125,7 @@ where
State: Send + Sync + 'static,
E: Endpoint<State>,
{
fn call<'a>(&'a self, req: Request<State>) -> BoxFuture<'a, crate::Result> {
fn call<'a>(&'a self, req: Request, _: State) -> BoxFuture<'a, crate::Result> {
let next = Next {
endpoint: &self.endpoint,
next_middleware: &self.middleware,
Expand Down
4 changes: 2 additions & 2 deletions src/fs/serve_dir.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ impl<State> Endpoint<State> for ServeDir
where
State: Send + Sync + 'static,
{
fn call<'a>(&'a self, req: Request<State>) -> BoxFuture<'a, Result> {
fn call<'a>(&'a self, req: Request) -> BoxFuture<'a, Result> {
let path = req.url().path();
let path = path.trim_start_matches(&self.prefix);
let path = path.trim_start_matches('/');
Expand Down Expand Up @@ -81,7 +81,7 @@ mod test {
})
}

fn request(path: &str) -> crate::Request<()> {
fn request(path: &str) -> crate::Request {
let request = crate::http::Request::get(
crate::http::Url::parse(&format!("http://localhost/{}", path)).unwrap(),
);
Expand Down
4 changes: 2 additions & 2 deletions src/log/middleware.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ impl LogMiddleware {
/// Log a request and a response.
async fn log<'a, State: Send + Sync + 'static>(
&'a self,
ctx: Request<State>,
ctx: Request,
next: Next<'a, State>,
) -> crate::Result {
let path = ctx.url().path().to_owned();
Expand Down Expand Up @@ -78,7 +78,7 @@ impl LogMiddleware {
impl<State: Send + Sync + 'static> Middleware<State> for LogMiddleware {
fn handle<'a>(
&'a self,
ctx: Request<State>,
ctx: Request,
next: Next<'a, State>,
) -> BoxFuture<'a, crate::Result> {
Box::pin(async move { self.log(ctx, next).await })
Expand Down
26 changes: 22 additions & 4 deletions src/middleware.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ pub trait Middleware<State>: Send + Sync + 'static {
/// Asynchronously handle the request, and return a response.
fn handle<'a>(
&'a self,
request: Request<State>,
request: Request,
next: Next<'a, State>,
) -> BoxFuture<'a, crate::Result>;

Expand All @@ -26,17 +26,35 @@ where
F: Send
+ Sync
+ 'static
+ for<'a> Fn(Request<State>, Next<'a, State>) -> BoxFuture<'a, crate::Result>,
+ for<'a> Fn(Request, Next<'a, State>) -> BoxFuture<'a, crate::Result>,
{
fn handle<'a>(
&'a self,
req: Request<State>,
req: Request,
_: State,
next: Next<'a, State>,
) -> BoxFuture<'a, crate::Result> {
(self)(req, next)
}
}

impl<State, F> Middleware<State> for F
where
F: Send
+ Sync
+ 'static
+ for<'a> Fn(Request, State, Next<'a, State>) -> BoxFuture<'a, crate::Result>,
{
fn handle<'a>(
&'a self,
req: Request,
state: State,
next: Next<'a, State>,
) -> BoxFuture<'a, crate::Result> {
(self)(req, state, next)
}
}

/// The remainder of a middleware chain, including the endpoint.
#[allow(missing_debug_implementations)]
pub struct Next<'a, State> {
Expand All @@ -47,7 +65,7 @@ pub struct Next<'a, State> {
impl<'a, State: Send + Sync + 'static> Next<'a, State> {
/// Asynchronously execute the remaining middleware chain.
#[must_use]
pub fn run(mut self, req: Request<State>) -> BoxFuture<'a, Response> {
pub fn run(mut self, req: Request) -> BoxFuture<'a, Response> {
Box::pin(async move {
if let Some((current, next)) = self.next_middleware.split_first() {
self.next_middleware = next;
Expand Down
2 changes: 1 addition & 1 deletion src/redirect.rs
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ where
State: Send + Sync + 'static,
T: AsRef<str> + Send + Sync + 'static,
{
fn call<'a>(&'a self, _req: Request<State>) -> BoxFuture<'a, crate::Result<Response>> {
fn call<'a>(&'a self, _req: Request) -> BoxFuture<'a, crate::Result<Response>> {
let res = self.into();
Box::pin(async move { Ok(res) })
}
Expand Down
39 changes: 15 additions & 24 deletions src/request.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ use route_recognizer::Params;

use std::ops::Index;
use std::pin::Pin;
use std::{fmt, str::FromStr, sync::Arc};
use std::{fmt, str::FromStr};

use crate::cookies::CookieData;
use crate::http::cookies::Cookie;
Expand All @@ -20,8 +20,7 @@ use crate::Response;
/// Requests also provide *extensions*, a type map primarily used for low-level
/// communication between middleware and endpoints.
#[derive(Debug)]
pub struct Request<State> {
pub(crate) state: Arc<State>,
pub struct Request {
pub(crate) req: http::Request,
pub(crate) route_params: Vec<Params>,
}
Expand All @@ -43,15 +42,13 @@ impl<E: fmt::Debug + fmt::Display> fmt::Display for ParamError<E> {

impl<T: fmt::Debug + fmt::Display> std::error::Error for ParamError<T> {}

impl<State> Request<State> {
impl Request {
/// Create a new `Request`.
pub(crate) fn new(
state: Arc<State>,
req: http_types::Request,
route_params: Vec<Params>,
) -> Self {
Self {
state,
req,
route_params,
}
Expand Down Expand Up @@ -266,12 +263,6 @@ impl<State> Request<State> {
self.req.ext_mut().insert(val)
}

#[must_use]
/// Access application scoped state.
pub fn state(&self) -> &State {
&self.state
}

/// Extract and parse a route parameter by name.
///
/// Returns the results of parsing the parameter according to the inferred
Expand Down Expand Up @@ -524,31 +515,31 @@ impl<State> Request<State> {
}
}

impl<State> AsRef<http::Request> for Request<State> {
impl AsRef<http::Request> for Request {
fn as_ref(&self) -> &http::Request {
&self.req
}
}

impl<State> AsMut<http::Request> for Request<State> {
impl AsMut<http::Request> for Request {
fn as_mut(&mut self) -> &mut http::Request {
&mut self.req
}
}

impl<State> AsRef<http::Headers> for Request<State> {
impl AsRef<http::Headers> for Request {
fn as_ref(&self) -> &http::Headers {
self.req.as_ref()
}
}

impl<State> AsMut<http::Headers> for Request<State> {
impl AsMut<http::Headers> for Request {
fn as_mut(&mut self) -> &mut http::Headers {
self.req.as_mut()
}
}

impl<State> Read for Request<State> {
impl Read for Request {
fn poll_read(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
Expand All @@ -558,23 +549,23 @@ impl<State> Read for Request<State> {
}
}

impl<State> Into<http::Request> for Request<State> {
impl Into<http::Request> for Request {
fn into(self) -> http::Request {
self.req
}
}

// NOTE: From cannot be implemented for this conversion because `State` needs to
// be constrained by a type.
impl<State: Send + Sync + 'static> Into<Response> for Request<State> {
impl Into<Response> for Request {
fn into(mut self) -> Response {
let mut res = Response::new(StatusCode::Ok);
res.set_body(self.take_body());
res
}
}

impl<State> IntoIterator for Request<State> {
impl IntoIterator for Request {
type Item = (HeaderName, HeaderValues);
type IntoIter = http_types::headers::IntoIter;

Expand All @@ -585,7 +576,7 @@ impl<State> IntoIterator for Request<State> {
}
}

impl<'a, State> IntoIterator for &'a Request<State> {
impl<'a> IntoIterator for &'a Request {
type Item = (&'a HeaderName, &'a HeaderValues);
type IntoIter = http_types::headers::Iter<'a>;

Expand All @@ -595,7 +586,7 @@ impl<'a, State> IntoIterator for &'a Request<State> {
}
}

impl<'a, State> IntoIterator for &'a mut Request<State> {
impl<'a> IntoIterator for &'a mut Request {
type Item = (&'a HeaderName, &'a mut HeaderValues);
type IntoIter = http_types::headers::IterMut<'a>;

Expand All @@ -605,7 +596,7 @@ impl<'a, State> IntoIterator for &'a mut Request<State> {
}
}

impl<State> Index<HeaderName> for Request<State> {
impl Index<HeaderName> for Request {
type Output = HeaderValues;

/// Returns a reference to the value corresponding to the supplied name.
Expand All @@ -619,7 +610,7 @@ impl<State> Index<HeaderName> for Request<State> {
}
}

impl<State> Index<&str> for Request<State> {
impl Index<&str> for Request {
type Output = HeaderValues;

/// Returns a reference to the value corresponding to the supplied name.
Expand Down
4 changes: 1 addition & 3 deletions src/route.rs
Original file line number Diff line number Diff line change
Expand Up @@ -279,9 +279,8 @@ where
State: Send + Sync + 'static,
E: Endpoint<State>,
{
fn call<'a>(&'a self, req: crate::Request<State>) -> BoxFuture<'a, crate::Result> {
fn call<'a>(&'a self, req: crate::Request, _: State) -> BoxFuture<'a, crate::Result> {
let crate::Request {
state,
mut req,
route_params,
} = req;
Expand All @@ -290,7 +289,6 @@ where
req.url_mut().set_path(&rest);

self.0.call(crate::Request {
state,
req,
route_params,
})
Expand Down
8 changes: 4 additions & 4 deletions src/router.rs
Original file line number Diff line number Diff line change
Expand Up @@ -82,14 +82,14 @@ impl<State: Send + Sync + 'static> Router<State> {
}
}

fn not_found_endpoint<State: Send + Sync + 'static>(
_req: Request<State>,
fn not_found_endpoint(
_req: Request,
) -> BoxFuture<'static, crate::Result> {
Box::pin(async { Ok(Response::new(StatusCode::NotFound)) })
}

fn method_not_allowed<State: Send + Sync + 'static>(
_req: Request<State>,
fn method_not_allowed(
_req: Request,
) -> BoxFuture<'static, crate::Result> {
Box::pin(async { Ok(Response::new(StatusCode::MethodNotAllowed)) })
}
2 changes: 1 addition & 1 deletion src/security/cors.rs
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,7 @@ impl CorsMiddleware {
}

impl<State: Send + Sync + 'static> Middleware<State> for CorsMiddleware {
fn handle<'a>(&'a self, req: Request<State>, next: Next<'a, State>) -> BoxFuture<'a, Result> {
fn handle<'a>(&'a self, req: Request, next: Next<'a, State>) -> BoxFuture<'a, Result> {
Box::pin(async move {
// TODO: how should multiple origin values be handled?
let origins = req.header(&headers::ORIGIN).cloned();
Expand Down
2 changes: 1 addition & 1 deletion src/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -442,7 +442,7 @@ impl<State> Clone for Server<State> {
impl<State: Sync + Send + 'static, InnerState: Sync + Send + 'static> Endpoint<State>
for Server<InnerState>
{
fn call<'a>(&'a self, req: Request<State>) -> BoxFuture<'a, crate::Result> {
fn call<'a>(&'a self, req: Request) -> BoxFuture<'a, crate::Result> {
let Request {
req,
mut route_params,
Expand Down
Loading