|
| 1 | +use std::pin::Pin; |
| 2 | + |
| 3 | +use futures::{stream, Stream}; |
1 | 4 | use integration_tests::{ |
2 | 5 | pb::{test1_client, test1_server, Input1, Output1}, |
3 | 6 | trace_init, |
@@ -110,6 +113,81 @@ fn max_message_send_size() { |
110 | 113 | }); |
111 | 114 | } |
112 | 115 |
|
| 116 | +#[tokio::test] |
| 117 | +async fn response_stream_limit() { |
| 118 | + let client_blob = vec![0; 1]; |
| 119 | + |
| 120 | + let (client, server) = tokio::io::duplex(1024); |
| 121 | + |
| 122 | + struct Svc; |
| 123 | + |
| 124 | + #[tonic::async_trait] |
| 125 | + impl test1_server::Test1 for Svc { |
| 126 | + async fn unary_call(&self, _req: Request<Input1>) -> Result<Response<Output1>, Status> { |
| 127 | + unimplemented!() |
| 128 | + } |
| 129 | + |
| 130 | + type StreamCallStream = |
| 131 | + Pin<Box<dyn Stream<Item = Result<Output1, Status>> + Send + 'static>>; |
| 132 | + |
| 133 | + async fn stream_call( |
| 134 | + &self, |
| 135 | + _req: Request<Input1>, |
| 136 | + ) -> Result<Response<Self::StreamCallStream>, Status> { |
| 137 | + let blob = Output1 { |
| 138 | + buf: vec![0; 6877902], |
| 139 | + }; |
| 140 | + let stream = stream::iter(vec![Ok(blob.clone()), Ok(blob.clone())]); |
| 141 | + |
| 142 | + Ok(Response::new(Box::pin(stream))) |
| 143 | + } |
| 144 | + } |
| 145 | + |
| 146 | + let svc = test1_server::Test1Server::new(Svc); |
| 147 | + |
| 148 | + tokio::spawn(async move { |
| 149 | + Server::builder() |
| 150 | + .add_service(svc) |
| 151 | + .serve_with_incoming(futures::stream::iter(vec![Ok::<_, std::io::Error>(server)])) |
| 152 | + .await |
| 153 | + .unwrap(); |
| 154 | + }); |
| 155 | + |
| 156 | + // Move client to an option so we can _move_ the inner value |
| 157 | + // on the first attempt to connect. All other attempts will fail. |
| 158 | + let mut client = Some(client); |
| 159 | + let channel = Endpoint::try_from("http://[::]:50051") |
| 160 | + .unwrap() |
| 161 | + .connect_with_connector(tower::service_fn(move |_| { |
| 162 | + let client = client.take(); |
| 163 | + |
| 164 | + async move { |
| 165 | + if let Some(client) = client { |
| 166 | + Ok(client) |
| 167 | + } else { |
| 168 | + Err(std::io::Error::new( |
| 169 | + std::io::ErrorKind::Other, |
| 170 | + "Client already taken", |
| 171 | + )) |
| 172 | + } |
| 173 | + } |
| 174 | + })) |
| 175 | + .await |
| 176 | + .unwrap(); |
| 177 | + |
| 178 | + let client = test1_client::Test1Client::new(channel); |
| 179 | + |
| 180 | + let mut client = client.max_decoding_message_size(6877902 + 5); |
| 181 | + |
| 182 | + let req = Request::new(Input1 { |
| 183 | + buf: client_blob.clone(), |
| 184 | + }); |
| 185 | + |
| 186 | + let mut stream = client.stream_call(req).await.unwrap().into_inner(); |
| 187 | + |
| 188 | + while let Some(_b) = stream.message().await.unwrap() {} |
| 189 | +} |
| 190 | + |
113 | 191 | // Track caller doesn't work on async fn so we extract the async part |
114 | 192 | // into a sync version and assert the response there using track track_caller |
115 | 193 | // so that when this does panic it tells us which line in the test failed not |
@@ -210,6 +288,16 @@ async fn max_message_run(case: &TestCase) -> Result<(), Status> { |
210 | 288 | buf: self.0.clone(), |
211 | 289 | })) |
212 | 290 | } |
| 291 | + |
| 292 | + type StreamCallStream = |
| 293 | + Pin<Box<dyn Stream<Item = Result<Output1, Status>> + Send + 'static>>; |
| 294 | + |
| 295 | + async fn stream_call( |
| 296 | + &self, |
| 297 | + _req: Request<Input1>, |
| 298 | + ) -> Result<Response<Self::StreamCallStream>, Status> { |
| 299 | + unimplemented!() |
| 300 | + } |
213 | 301 | } |
214 | 302 |
|
215 | 303 | let svc = test1_server::Test1Server::new(Svc(server_blob)); |
|
0 commit comments