Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 20 additions & 12 deletions tonic-web/src/call.rs
Original file line number Diff line number Diff line change
Expand Up @@ -155,15 +155,16 @@ impl<B> GrpcWebCall<B> {
}
}

impl<B> GrpcWebCall<B>
impl<B, D> GrpcWebCall<B>
where
B: Body<Data = Bytes>,
B: Body<Data = D>,
B::Error: Error,
D: Buf,
{
fn poll_decode(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
) -> Poll<Option<Result<B::Data, Status>>> {
) -> Poll<Option<Result<Bytes, Status>>> {
match self.encoding {
Encoding::Base64 => loop {
if let Some(bytes) = self.as_mut().decode_chunk()? {
Expand All @@ -186,7 +187,10 @@ where
},

Encoding::None => match ready!(self.project().inner.poll_data(cx)) {
Some(res) => Poll::Ready(Some(res.map_err(internal_error))),
Some(res) => Poll::Ready(Some(
res.map(|mut d| d.copy_to_bytes(d.remaining()))
.map_err(internal_error),
)),
None => Poll::Ready(None),
},
}
Expand All @@ -195,15 +199,18 @@ where
fn poll_encode(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
) -> Poll<Option<Result<B::Data, Status>>> {
) -> Poll<Option<Result<Bytes, Status>>> {
let mut this = self.as_mut().project();

if let Some(mut res) = ready!(this.inner.as_mut().poll_data(cx)) {
if *this.encoding == Encoding::Base64 {
res = res.map(|b| crate::util::base64::STANDARD.encode(b).into())
}
if let Some(res) = ready!(this.inner.as_mut().poll_data(cx)) {
let res = res.map(|mut d| d.copy_to_bytes(d.remaining()));
let bytes = if *this.encoding == Encoding::Base64 {
res.map(|b| crate::util::base64::STANDARD.encode(b).into())
} else {
res
};

return Poll::Ready(Some(res.map_err(internal_error)));
return Poll::Ready(Some(bytes.map_err(internal_error)));
}

// this flag is needed because the inner stream never
Expand All @@ -229,10 +236,11 @@ where
}
}

impl<B> Body for GrpcWebCall<B>
impl<B, D> Body for GrpcWebCall<B>
where
B: Body<Data = Bytes>,
B: Body<Data = D>,
B::Error: Error,
D: Buf,
{
type Data = Bytes;
type Error = Status;
Expand Down
35 changes: 24 additions & 11 deletions tonic-web/src/layer.rs
Original file line number Diff line number Diff line change
@@ -1,35 +1,48 @@
use std::error::Error;

use super::{BoxBody, BoxError, GrpcWebService};

use tower_layer::Layer;
use tower_service::Service;

/// Layer implementing the grpc-web protocol.
#[derive(Debug, Clone)]
pub struct GrpcWebLayer {
_priv: (),
#[derive(Debug)]
pub struct GrpcWebLayer<ResBody = BoxBody> {
_markers: std::marker::PhantomData<ResBody>,
}

impl<ResBody> Clone for GrpcWebLayer<ResBody> {
fn clone(&self) -> Self {
Self {
_markers: std::marker::PhantomData,
}
}
}

impl GrpcWebLayer {
impl<ResBody> GrpcWebLayer<ResBody> {
/// Create a new grpc-web layer.
pub fn new() -> GrpcWebLayer {
Self { _priv: () }
pub fn new() -> Self {
Self {
_markers: std::marker::PhantomData,
}
}
}

impl Default for GrpcWebLayer {
impl<ResBody> Default for GrpcWebLayer<ResBody> {
fn default() -> Self {
Self::new()
}
}

impl<S> Layer<S> for GrpcWebLayer
impl<S, ResBody> Layer<S> for GrpcWebLayer<ResBody>
where
S: Service<http::Request<hyper::Body>, Response = http::Response<BoxBody>>,
S: Send + 'static,
S: Service<http::Request<BoxBody>, Response = http::Response<ResBody>> + Send + 'static,
S::Future: Send + 'static,
S::Error: Into<BoxError> + Send,
ResBody: http_body::Body + Send + 'static,
ResBody::Error: Error + Send + 'static,
{
type Service = GrpcWebService<S>;
type Service = GrpcWebService<S, ResBody>;

fn layer(&self, inner: S) -> Self::Service {
GrpcWebService::new(inner)
Expand Down
61 changes: 43 additions & 18 deletions tonic-web/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -107,30 +107,34 @@ mod client;
mod layer;
mod service;

use bytes::Buf;
use http::header::HeaderName;
use std::time::Duration;
use tonic::{body::BoxBody, server::NamedService};
use tower_http::cors::{AllowOrigin, CorsLayer};
use http_body::Body;
use std::{error::Error, time::Duration};
use tonic::{body::BoxBody, server::NamedService, Status};
use tower_http::cors::{AllowOrigin, Cors, CorsLayer};
use tower_layer::Layer;
use tower_service::Service;

/// Alias for a type-erased error type.
pub type BoxError = Box<dyn std::error::Error + Send + Sync>;

const DEFAULT_MAX_AGE: Duration = Duration::from_secs(24 * 60 * 60);
const DEFAULT_EXPOSED_HEADERS: [&str; 3] =
["grpc-status", "grpc-message", "grpc-status-details-bin"];
const DEFAULT_ALLOW_HEADERS: [&str; 4] =
["x-grpc-web", "content-type", "x-user-agent", "grpc-timeout"];

type BoxError = Box<dyn std::error::Error + Send + Sync>;

/// Enable a tonic service to handle grpc-web requests with the default configuration.
///
/// You can customize the CORS configuration composing the [`GrpcWebLayer`] with the cors layer of your choice.
pub fn enable<S>(service: S) -> CorsGrpcWeb<S>
pub fn enable<S, ResBody>(service: S) -> CorsGrpcWeb<S, ResBody>
where
S: Service<http::Request<hyper::Body>, Response = http::Response<BoxBody>>,
S: Service<http::Request<BoxBody>, Response = http::Response<ResBody>>,
S: Clone + Send + 'static,
S::Future: Send + 'static,
S::Error: Into<BoxError> + Send,
ResBody: Body,
{
let cors = CorsLayer::new()
.allow_origin(AllowOrigin::mirror_request())
Expand All @@ -156,34 +160,46 @@ where

/// A newtype wrapper around [`GrpcWebLayer`] and [`tower_http::cors::CorsLayer`] to allow
/// `tonic_web::enable` to implement the [`NamedService`] trait.
#[derive(Debug, Clone)]
pub struct CorsGrpcWeb<S>(tower_http::cors::Cors<GrpcWebService<S>>);
#[derive(Debug)]
pub struct CorsGrpcWeb<S, ResBody = BoxBody>(Cors<GrpcWebService<S, ResBody>>);

impl<S: Clone, ResBody> Clone for CorsGrpcWeb<S, ResBody> {
fn clone(&self) -> Self {
Self(self.0.clone())
}
}

impl<S> Service<http::Request<hyper::Body>> for CorsGrpcWeb<S>
impl<S, ReqBody, ResBody> Service<http::Request<ReqBody>> for CorsGrpcWeb<S, ResBody>
where
S: Service<http::Request<hyper::Body>, Response = http::Response<BoxBody>>,
S: Service<http::Request<BoxBody>, Response = http::Response<ResBody>>,
S: Clone + Send + 'static,
S::Future: Send + 'static,
S::Error: Into<BoxError> + Send,
ReqBody: Body + Send + 'static,
ReqBody::Error: Error + Send + Sync,
ResBody: Body + Default + Send + 'static,
ResBody::Error: Error + Send + Sync + 'static,
{
type Response = S::Response;
type Error = S::Error;
type Future =
<tower_http::cors::Cors<GrpcWebService<S>> as Service<http::Request<hyper::Body>>>::Future;
type Response = <Cors<GrpcWebService<S, ResBody>> as Service<http::Request<ReqBody>>>::Response;
type Error = <Cors<GrpcWebService<S, ResBody>> as Service<http::Request<ReqBody>>>::Error;
type Future = <Cors<GrpcWebService<S, ResBody>> as Service<http::Request<ReqBody>>>::Future;

fn poll_ready(
&mut self,
cx: &mut std::task::Context<'_>,
) -> std::task::Poll<Result<(), Self::Error>> {
self.0.poll_ready(cx)
<Cors<GrpcWebService<S, ResBody>> as Service<http::Request<ReqBody>>>::poll_ready(
&mut self.0,
cx,
)
}

fn call(&mut self, req: http::Request<hyper::Body>) -> Self::Future {
fn call(&mut self, req: http::Request<ReqBody>) -> Self::Future {
self.0.call(req)
}
}

impl<S> NamedService for CorsGrpcWeb<S>
impl<S, ResBody> NamedService for CorsGrpcWeb<S, ResBody>
where
S: NamedService,
{
Expand All @@ -208,3 +224,12 @@ pub(crate) mod util {
);
}
}

pub(crate) fn box_body<D: Buf, E: Into<BoxError> + Send>(
body: impl Body<Data = D, Error = E> + Send + 'static,
) -> BoxBody {
let bod = body
.map_data(|mut d| d.copy_to_bytes(d.remaining()))
.map_err(|e| Status::from_error(e.into() as BoxError));
bod.boxed_unsync()
}
Loading