Skip to content

Commit 7a6b20d

Browse files
feat: add GrpcMethod extension into request for client (#1275)
* feat: add GrpcMethod extension into request for client * refactor: change GrpcMethod fields into private and expose methods instead * refactor: hide GrpcMethod::new in doc --------- Co-authored-by: Lucio Franco <luciofranco14@gmail.com>
1 parent 1547f96 commit 7a6b20d

File tree

9 files changed

+207
-85
lines changed

9 files changed

+207
-85
lines changed
Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
use std::time::Duration;
2+
3+
use futures::{channel::oneshot, FutureExt};
4+
use integration_tests::pb::{test_client::TestClient, test_server, Input, Output};
5+
use tonic::{
6+
transport::{Endpoint, Server},
7+
GrpcMethod, Request, Response, Status,
8+
};
9+
10+
#[tokio::test]
11+
async fn interceptor_retrieves_grpc_method() {
12+
use test_server::Test;
13+
14+
struct Svc;
15+
16+
#[tonic::async_trait]
17+
impl Test for Svc {
18+
async fn unary_call(&self, _: Request<Input>) -> Result<Response<Output>, Status> {
19+
Ok(Response::new(Output {}))
20+
}
21+
}
22+
23+
let svc = test_server::TestServer::new(Svc);
24+
25+
let (tx, rx) = oneshot::channel();
26+
// Start the server now, second call should succeed
27+
let jh = tokio::spawn(async move {
28+
Server::builder()
29+
.add_service(svc)
30+
.serve_with_shutdown("127.0.0.1:1340".parse().unwrap(), rx.map(drop))
31+
.await
32+
.unwrap();
33+
});
34+
35+
let channel = Endpoint::from_static("http://127.0.0.1:1340").connect_lazy();
36+
37+
fn client_intercept(req: Request<()>) -> Result<Request<()>, Status> {
38+
println!("Intercepting client request: {:?}", req);
39+
40+
let gm = req.extensions().get::<GrpcMethod>().unwrap();
41+
assert_eq!(gm.service(), "test.Test");
42+
assert_eq!(gm.method(), "UnaryCall");
43+
44+
Ok(req)
45+
}
46+
let mut client = TestClient::with_interceptor(channel, client_intercept);
47+
48+
tokio::time::sleep(Duration::from_millis(100)).await;
49+
client.unary_call(Request::new(Input {})).await.unwrap();
50+
51+
tx.send(()).unwrap();
52+
jh.await.unwrap();
53+
}

tonic-build/src/client.rs

Lines changed: 76 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,10 @@
11
use std::collections::HashSet;
22

33
use super::{Attributes, Method, Service};
4-
use crate::{format_method_name, generate_doc_comments, naive_snake_case};
4+
use crate::{
5+
format_method_name, format_method_path, format_service_name, generate_doc_comments,
6+
naive_snake_case,
7+
};
58
use proc_macro2::TokenStream;
69
use quote::{format_ident, quote};
710

@@ -51,21 +54,16 @@ pub(crate) fn generate_internal<T: Service>(
5154
let connect = generate_connect(&service_ident, build_transport);
5255

5356
let package = if emit_package { service.package() } else { "" };
54-
let path = format!(
55-
"{}{}{}",
56-
package,
57-
if package.is_empty() { "" } else { "." },
58-
service.identifier()
59-
);
57+
let service_name = format_service_name(service, emit_package);
6058

61-
let service_doc = if disable_comments.contains(&path) {
59+
let service_doc = if disable_comments.contains(&service_name) {
6260
TokenStream::new()
6361
} else {
6462
generate_doc_comments(service.comment())
6563
};
6664

6765
let mod_attributes = attributes.for_mod(package);
68-
let struct_attributes = attributes.for_struct(&path);
66+
let struct_attributes = attributes.for_struct(&service_name);
6967

7068
quote! {
7169
/// Generated client implementations.
@@ -193,30 +191,41 @@ fn generate_methods<T: Service>(
193191
disable_comments: &HashSet<String>,
194192
) -> TokenStream {
195193
let mut stream = TokenStream::new();
196-
let package = if emit_package { service.package() } else { "" };
197194

198195
for method in service.methods() {
199-
let path = format!(
200-
"/{}{}{}/{}",
201-
package,
202-
if package.is_empty() { "" } else { "." },
203-
service.identifier(),
204-
method.identifier()
205-
);
206-
207-
if !disable_comments.contains(&format_method_name(package, service, method)) {
196+
if !disable_comments.contains(&format_method_name(service, method, emit_package)) {
208197
stream.extend(generate_doc_comments(method.comment()));
209198
}
210199

211200
let method = match (method.client_streaming(), method.server_streaming()) {
212-
(false, false) => generate_unary(method, proto_path, compile_well_known_types, path),
213-
(false, true) => {
214-
generate_server_streaming(method, proto_path, compile_well_known_types, path)
215-
}
216-
(true, false) => {
217-
generate_client_streaming(method, proto_path, compile_well_known_types, path)
218-
}
219-
(true, true) => generate_streaming(method, proto_path, compile_well_known_types, path),
201+
(false, false) => generate_unary(
202+
service,
203+
method,
204+
emit_package,
205+
proto_path,
206+
compile_well_known_types,
207+
),
208+
(false, true) => generate_server_streaming(
209+
service,
210+
method,
211+
emit_package,
212+
proto_path,
213+
compile_well_known_types,
214+
),
215+
(true, false) => generate_client_streaming(
216+
service,
217+
method,
218+
emit_package,
219+
proto_path,
220+
compile_well_known_types,
221+
),
222+
(true, true) => generate_streaming(
223+
service,
224+
method,
225+
emit_package,
226+
proto_path,
227+
compile_well_known_types,
228+
),
220229
};
221230

222231
stream.extend(method);
@@ -225,15 +234,19 @@ fn generate_methods<T: Service>(
225234
stream
226235
}
227236

228-
fn generate_unary<T: Method>(
229-
method: &T,
237+
fn generate_unary<T: Service>(
238+
service: &T,
239+
method: &T::Method,
240+
emit_package: bool,
230241
proto_path: &str,
231242
compile_well_known_types: bool,
232-
path: String,
233243
) -> TokenStream {
234244
let codec_name = syn::parse_str::<syn::Path>(method.codec_path()).unwrap();
235245
let ident = format_ident!("{}", method.name());
236246
let (request, response) = method.request_response_name(proto_path, compile_well_known_types);
247+
let service_name = format_service_name(service, emit_package);
248+
let path = format_method_path(service, method, emit_package);
249+
let method_name = method.identifier();
237250

238251
quote! {
239252
pub async fn #ident(
@@ -245,21 +258,26 @@ fn generate_unary<T: Method>(
245258
})?;
246259
let codec = #codec_name::default();
247260
let path = http::uri::PathAndQuery::from_static(#path);
248-
self.inner.unary(request.into_request(), path, codec).await
261+
let mut req = request.into_request();
262+
req.extensions_mut().insert(GrpcMethod::new(#service_name, #method_name));
263+
self.inner.unary(req, path, codec).await
249264
}
250265
}
251266
}
252267

253-
fn generate_server_streaming<T: Method>(
254-
method: &T,
268+
fn generate_server_streaming<T: Service>(
269+
service: &T,
270+
method: &T::Method,
271+
emit_package: bool,
255272
proto_path: &str,
256273
compile_well_known_types: bool,
257-
path: String,
258274
) -> TokenStream {
259275
let codec_name = syn::parse_str::<syn::Path>(method.codec_path()).unwrap();
260276
let ident = format_ident!("{}", method.name());
261-
262277
let (request, response) = method.request_response_name(proto_path, compile_well_known_types);
278+
let service_name = format_service_name(service, emit_package);
279+
let path = format_method_path(service, method, emit_package);
280+
let method_name = method.identifier();
263281

264282
quote! {
265283
pub async fn #ident(
@@ -271,21 +289,26 @@ fn generate_server_streaming<T: Method>(
271289
})?;
272290
let codec = #codec_name::default();
273291
let path = http::uri::PathAndQuery::from_static(#path);
274-
self.inner.server_streaming(request.into_request(), path, codec).await
292+
let mut req = request.into_request();
293+
req.extensions_mut().insert(GrpcMethod::new(#service_name, #method_name));
294+
self.inner.server_streaming(req, path, codec).await
275295
}
276296
}
277297
}
278298

279-
fn generate_client_streaming<T: Method>(
280-
method: &T,
299+
fn generate_client_streaming<T: Service>(
300+
service: &T,
301+
method: &T::Method,
302+
emit_package: bool,
281303
proto_path: &str,
282304
compile_well_known_types: bool,
283-
path: String,
284305
) -> TokenStream {
285306
let codec_name = syn::parse_str::<syn::Path>(method.codec_path()).unwrap();
286307
let ident = format_ident!("{}", method.name());
287-
288308
let (request, response) = method.request_response_name(proto_path, compile_well_known_types);
309+
let service_name = format_service_name(service, emit_package);
310+
let path = format_method_path(service, method, emit_package);
311+
let method_name = method.identifier();
289312

290313
quote! {
291314
pub async fn #ident(
@@ -297,21 +320,26 @@ fn generate_client_streaming<T: Method>(
297320
})?;
298321
let codec = #codec_name::default();
299322
let path = http::uri::PathAndQuery::from_static(#path);
300-
self.inner.client_streaming(request.into_streaming_request(), path, codec).await
323+
let mut req = request.into_streaming_request();
324+
req.extensions_mut().insert(GrpcMethod::new(#service_name, #method_name));
325+
self.inner.client_streaming(req, path, codec).await
301326
}
302327
}
303328
}
304329

305-
fn generate_streaming<T: Method>(
306-
method: &T,
330+
fn generate_streaming<T: Service>(
331+
service: &T,
332+
method: &T::Method,
333+
emit_package: bool,
307334
proto_path: &str,
308335
compile_well_known_types: bool,
309-
path: String,
310336
) -> TokenStream {
311337
let codec_name = syn::parse_str::<syn::Path>(method.codec_path()).unwrap();
312338
let ident = format_ident!("{}", method.name());
313-
314339
let (request, response) = method.request_response_name(proto_path, compile_well_known_types);
340+
let service_name = format_service_name(service, emit_package);
341+
let path = format_method_path(service, method, emit_package);
342+
let method_name = method.identifier();
315343

316344
quote! {
317345
pub async fn #ident(
@@ -323,7 +351,9 @@ fn generate_streaming<T: Method>(
323351
})?;
324352
let codec = #codec_name::default();
325353
let path = http::uri::PathAndQuery::from_static(#path);
326-
self.inner.streaming(request.into_streaming_request(), path, codec).await
354+
let mut req = request.into_streaming_request();
355+
req.extensions_mut().insert(GrpcMethod::new(#service_name,#method_name));
356+
self.inner.streaming(req, path, codec).await
327357
}
328358
}
329359
}

tonic-build/src/lib.rs

Lines changed: 18 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -197,16 +197,28 @@ impl Attributes {
197197
}
198198
}
199199

200-
fn format_method_name<T: Service>(
201-
package: &str,
202-
service: &T,
203-
method: &<T as Service>::Method,
204-
) -> String {
200+
fn format_service_name<T: Service>(service: &T, emit_package: bool) -> String {
201+
let package = if emit_package { service.package() } else { "" };
205202
format!(
206-
"{}{}{}.{}",
203+
"{}{}{}",
207204
package,
208205
if package.is_empty() { "" } else { "." },
209206
service.identifier(),
207+
)
208+
}
209+
210+
fn format_method_path<T: Service>(service: &T, method: &T::Method, emit_package: bool) -> String {
211+
format!(
212+
"/{}/{}",
213+
format_service_name(service, emit_package),
214+
method.identifier()
215+
)
216+
}
217+
218+
fn format_method_name<T: Service>(service: &T, method: &T::Method, emit_package: bool) -> String {
219+
format!(
220+
"{}.{}",
221+
format_service_name(service, emit_package),
210222
method.identifier()
211223
)
212224
}

0 commit comments

Comments
 (0)