Skip to content

Commit 1abb8ef

Browse files
committed
tonic-web: proxy any kind of Service
This allows applying the GrpcWebLayer to any kind of Service, not just ones that tonic generates. This makes it possible to use tonic-web as a grpc-web proxy to a gRPC implemented in another language for example.
1 parent 65d909e commit 1abb8ef

File tree

8 files changed

+330
-144
lines changed

8 files changed

+330
-144
lines changed

tonic-web/src/call.rs

Lines changed: 30 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
use std::error::Error;
21
use std::pin::Pin;
32
use std::task::{Context, Poll};
43

@@ -9,6 +8,7 @@ use http::{header, HeaderMap, HeaderValue};
98
use http_body::{Body, SizeHint};
109
use pin_project::pin_project;
1110
use tonic::Status;
11+
use tower_http::BoxError;
1212

1313
use self::content_types::*;
1414

@@ -62,7 +62,7 @@ pub(crate) struct GrpcWebCall<B> {
6262
poll_trailers: bool,
6363
}
6464

65-
impl<B> GrpcWebCall<B> {
65+
impl<B: Body> GrpcWebCall<B> {
6666
pub(crate) fn request(inner: B, encoding: Encoding) -> Self {
6767
Self::new(inner, Direction::Request, encoding)
6868
}
@@ -108,15 +108,16 @@ impl<B> GrpcWebCall<B> {
108108
}
109109
}
110110

111-
impl<B> GrpcWebCall<B>
111+
impl<B, D> GrpcWebCall<B>
112112
where
113-
B: Body<Data = Bytes>,
114-
B::Error: Error,
113+
B: Body<Data = D>,
114+
B::Error: Into<BoxError> + Send + 'static,
115+
D: Buf,
115116
{
116117
fn poll_decode(
117118
mut self: Pin<&mut Self>,
118119
cx: &mut Context<'_>,
119-
) -> Poll<Option<Result<B::Data, Status>>> {
120+
) -> Poll<Option<Result<Bytes, Status>>> {
120121
match self.encoding {
121122
Encoding::Base64 => loop {
122123
if let Some(bytes) = self.as_mut().decode_chunk()? {
@@ -139,7 +140,10 @@ where
139140
},
140141

141142
Encoding::None => match ready!(self.project().inner.poll_data(cx)) {
142-
Some(res) => Poll::Ready(Some(res.map_err(internal_error))),
143+
Some(res) => Poll::Ready(Some(
144+
res.map(|mut d| d.copy_to_bytes(d.remaining()))
145+
.map_err(internal_error),
146+
)),
143147
None => Poll::Ready(None),
144148
},
145149
}
@@ -148,15 +152,20 @@ where
148152
fn poll_encode(
149153
mut self: Pin<&mut Self>,
150154
cx: &mut Context<'_>,
151-
) -> Poll<Option<Result<B::Data, Status>>> {
155+
) -> Poll<Option<Result<Bytes, Status>>> {
152156
let mut this = self.as_mut().project();
153157

154-
if let Some(mut res) = ready!(this.inner.as_mut().poll_data(cx)) {
158+
if let Some(res) = ready!(this.inner.as_mut().poll_data(cx)) {
159+
let res = res.map(|mut d| d.copy_to_bytes(d.remaining()));
160+
161+
let bytes: Result<Bytes, <B as Body>::Error>;
155162
if *this.encoding == Encoding::Base64 {
156-
res = res.map(|b| crate::util::base64::STANDARD.encode(b).into())
163+
bytes = res.map(|b| crate::util::base64::STANDARD.encode(b).into());
164+
} else {
165+
bytes = res;
157166
}
158167

159-
return Poll::Ready(Some(res.map_err(internal_error)));
168+
return Poll::Ready(Some(bytes.map_err(internal_error)));
160169
}
161170

162171
// this flag is needed because the inner stream never
@@ -182,10 +191,11 @@ where
182191
}
183192
}
184193

185-
impl<B> Body for GrpcWebCall<B>
194+
impl<B, D> Body for GrpcWebCall<B>
186195
where
187-
B: Body<Data = Bytes>,
188-
B::Error: Error,
196+
B: Body<Data = D>,
197+
B::Error: Into<BoxError> + Send + 'static,
198+
D: Buf,
189199
{
190200
type Data = Bytes;
191201
type Error = Status;
@@ -216,10 +226,11 @@ where
216226
}
217227
}
218228

219-
impl<B> Stream for GrpcWebCall<B>
229+
impl<B, D> Stream for GrpcWebCall<B>
220230
where
221-
B: Body<Data = Bytes>,
222-
B::Error: Error,
231+
B: Body<Data = D>,
232+
B::Error: Into<BoxError> + Send + 'static,
233+
D: Buf,
223234
{
224235
type Item = Result<Bytes, Status>;
225236

@@ -252,7 +263,8 @@ impl Encoding {
252263
}
253264
}
254265

255-
fn internal_error(e: impl std::fmt::Display) -> Status {
266+
fn internal_error(e: impl Into<BoxError>) -> Status {
267+
let e: BoxError = e.into();
256268
Status::internal(format!("tonic-web: {}", e))
257269
}
258270

tonic-web/src/layer.rs

Lines changed: 22 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -4,32 +4,43 @@ use tower_layer::Layer;
44
use tower_service::Service;
55

66
/// Layer implementing the grpc-web protocol.
7-
#[derive(Debug, Clone)]
8-
pub struct GrpcWebLayer {
9-
_priv: (),
7+
#[derive(Debug)]
8+
pub struct GrpcWebLayer<RespBody> {
9+
_markers: std::marker::PhantomData<RespBody>,
1010
}
1111

12-
impl GrpcWebLayer {
12+
impl<RespBody> Clone for GrpcWebLayer<RespBody> {
13+
fn clone(&self) -> Self {
14+
Self {
15+
_markers: std::marker::PhantomData,
16+
}
17+
}
18+
}
19+
20+
impl<RespBody> GrpcWebLayer<RespBody> {
1321
/// Create a new grpc-web layer.
14-
pub fn new() -> GrpcWebLayer {
15-
Self { _priv: () }
22+
pub fn new() -> Self {
23+
Self {
24+
_markers: std::marker::PhantomData,
25+
}
1626
}
1727
}
1828

19-
impl Default for GrpcWebLayer {
29+
impl<RespBody> Default for GrpcWebLayer<RespBody> {
2030
fn default() -> Self {
2131
Self::new()
2232
}
2333
}
2434

25-
impl<S> Layer<S> for GrpcWebLayer
35+
impl<S, RespBody> Layer<S> for GrpcWebLayer<RespBody>
2636
where
27-
S: Service<http::Request<hyper::Body>, Response = http::Response<BoxBody>>,
28-
S: Send + 'static,
37+
S: Service<http::Request<BoxBody>, Response = http::Response<RespBody>> + Send + 'static,
2938
S::Future: Send + 'static,
3039
S::Error: Into<BoxError> + Send,
40+
RespBody: http_body::Body + Send + 'static,
41+
RespBody::Error: Into<BoxError> + Send + 'static,
3142
{
32-
type Service = GrpcWebService<S>;
43+
type Service = GrpcWebService<S, RespBody>;
3344

3445
fn layer(&self, inner: S) -> Self::Service {
3546
GrpcWebService::new(inner)

tonic-web/src/lib.rs

Lines changed: 48 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,8 @@
9797
#![doc(html_root_url = "https://docs.rs/tonic-web/0.9.1")]
9898
#![doc(issue_tracker_base_url = "https://github.com/hyperium/tonic/issues/")]
9999

100+
use bytes::Buf;
101+
use http_body::Body;
100102
pub use layer::GrpcWebLayer;
101103
pub use service::{GrpcWebService, ResponseFuture};
102104

@@ -106,8 +108,11 @@ mod service;
106108

107109
use http::header::HeaderName;
108110
use std::time::Duration;
109-
use tonic::{body::BoxBody, server::NamedService};
110-
use tower_http::cors::{AllowOrigin, CorsLayer};
111+
use tonic::{body::BoxBody, server::NamedService, Status};
112+
use tower_http::{
113+
cors::{AllowOrigin, CorsLayer},
114+
BoxError,
115+
};
111116
use tower_layer::Layer;
112117
use tower_service::Service;
113118

@@ -117,17 +122,16 @@ const DEFAULT_EXPOSED_HEADERS: [&str; 3] =
117122
const DEFAULT_ALLOW_HEADERS: [&str; 4] =
118123
["x-grpc-web", "content-type", "x-user-agent", "grpc-timeout"];
119124

120-
type BoxError = Box<dyn std::error::Error + Send + Sync>;
121-
122125
/// Enable a tonic service to handle grpc-web requests with the default configuration.
123126
///
124127
/// You can customize the CORS configuration composing the [`GrpcWebLayer`] with the cors layer of your choice.
125-
pub fn enable<S>(service: S) -> CorsGrpcWeb<S>
128+
pub fn enable<S, RespBody>(service: S) -> CorsGrpcWeb<S, RespBody>
126129
where
127-
S: Service<http::Request<hyper::Body>, Response = http::Response<BoxBody>>,
130+
S: Service<http::Request<BoxBody>, Response = http::Response<RespBody>>,
128131
S: Clone + Send + 'static,
129132
S::Future: Send + 'static,
130133
S::Error: Into<BoxError> + Send,
134+
RespBody: Body,
131135
{
132136
let cors = CorsLayer::new()
133137
.allow_origin(AllowOrigin::mirror_request())
@@ -153,34 +157,52 @@ where
153157

154158
/// A newtype wrapper around [`GrpcWebLayer`] and [`tower_http::cors::CorsLayer`] to allow
155159
/// `tonic_web::enable` to implement the [`NamedService`] trait.
156-
#[derive(Debug, Clone)]
157-
pub struct CorsGrpcWeb<S>(tower_http::cors::Cors<GrpcWebService<S>>);
160+
#[derive(Debug)]
161+
pub struct CorsGrpcWeb<S, RespBody>(tower_http::cors::Cors<GrpcWebService<S, RespBody>>);
162+
163+
impl<S: Clone, RespBody> Clone for CorsGrpcWeb<S, RespBody> {
164+
fn clone(&self) -> Self {
165+
Self(self.0.clone())
166+
}
167+
}
158168

159-
impl<S> Service<http::Request<hyper::Body>> for CorsGrpcWeb<S>
169+
impl<S, ReqBody, RespBody> Service<http::Request<ReqBody>> for CorsGrpcWeb<S, RespBody>
160170
where
161-
S: Service<http::Request<hyper::Body>, Response = http::Response<BoxBody>>,
171+
S: Service<http::Request<BoxBody>, Response = http::Response<RespBody>>,
162172
S: Clone + Send + 'static,
163173
S::Future: Send + 'static,
164174
S::Error: Into<BoxError> + Send,
175+
ReqBody: Body + Send + 'static,
176+
ReqBody::Error: Into<BoxError> + Send,
177+
RespBody: Body + Default + Send + 'static,
178+
RespBody::Error: Into<BoxError> + Send + 'static,
165179
{
166-
type Response = S::Response;
167-
type Error = S::Error;
168-
type Future =
169-
<tower_http::cors::Cors<GrpcWebService<S>> as Service<http::Request<hyper::Body>>>::Future;
180+
type Response = <tower_http::cors::Cors<GrpcWebService<S, RespBody>> as Service<
181+
http::Request<ReqBody>,
182+
>>::Response;
183+
type Error = <tower_http::cors::Cors<GrpcWebService<S, RespBody>> as Service<
184+
http::Request<ReqBody>,
185+
>>::Error;
186+
type Future = <tower_http::cors::Cors<GrpcWebService<S, RespBody>> as Service<
187+
http::Request<ReqBody>,
188+
>>::Future;
170189

171190
fn poll_ready(
172191
&mut self,
173192
cx: &mut std::task::Context<'_>,
174193
) -> std::task::Poll<Result<(), Self::Error>> {
175-
self.0.poll_ready(cx)
194+
<tower_http::cors::Cors<GrpcWebService<S, RespBody>> as Service<http::Request<ReqBody>>>::poll_ready(
195+
&mut self.0,
196+
cx,
197+
)
176198
}
177199

178-
fn call(&mut self, req: http::Request<hyper::Body>) -> Self::Future {
200+
fn call(&mut self, req: http::Request<ReqBody>) -> Self::Future {
179201
self.0.call(req)
180202
}
181203
}
182204

183-
impl<S> NamedService for CorsGrpcWeb<S>
205+
impl<S, RespBody> NamedService for CorsGrpcWeb<S, RespBody>
184206
where
185207
S: NamedService,
186208
{
@@ -205,3 +227,12 @@ pub(crate) mod util {
205227
);
206228
}
207229
}
230+
231+
pub(crate) fn box_body<D: Buf, E: Into<BoxError> + Send>(
232+
body: impl Body<Data = D, Error = E> + Send + 'static,
233+
) -> BoxBody {
234+
let bod = body
235+
.map_data(|mut d| d.copy_to_bytes(d.remaining()))
236+
.map_err(|e| Status::from_error(e.into() as BoxError));
237+
bod.boxed_unsync()
238+
}

0 commit comments

Comments
 (0)