Skip to content

Commit ed3e7e9

Browse files
authored
fix(codec): Fix buffer decode panic on full (#43)
* fix(codec): Fix buffer decode panic on full This is a naive fix for the buffer growing beyond capacity and producing a panic. Ideally we should do a better job of not having to allocate for new messages by using a link list. * fmt
1 parent bd2b4e0 commit ed3e7e9

File tree

5 files changed

+134
-5
lines changed

5 files changed

+134
-5
lines changed

tonic/src/codec/decode.rs

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@ use std::{
1212
};
1313
use tracing::{debug, trace};
1414

15+
const BUFFER_SIZE: usize = 8 * 1024;
16+
1517
/// Streaming requests and responses.
1618
///
1719
/// This will wrap some inner [`Body`] and [`Decoder`] and provide an interface
@@ -70,6 +72,7 @@ impl<T> Streaming<T> {
7072
{
7173
Self::new(decoder, body, Direction::Request)
7274
}
75+
7376
fn new<B, D>(decoder: D, body: B, direction: Direction) -> Self
7477
where
7578
B: Body + Send + 'static,
@@ -82,8 +85,7 @@ impl<T> Streaming<T> {
8285
body: BoxBody::map_from(body),
8386
state: State::ReadHeader,
8487
direction,
85-
// FIXME: update this with a reasonable size
86-
buf: BytesMut::with_capacity(1024 * 1024),
88+
buf: BytesMut::with_capacity(BUFFER_SIZE),
8789
trailers: None,
8890
}
8991
}
@@ -234,6 +236,16 @@ impl<T> Stream for Streaming<T> {
234236
};
235237

236238
if let Some(data) = chunk {
239+
if data.remaining() > self.buf.remaining_mut() {
240+
let amt = if data.remaining() > BUFFER_SIZE {
241+
data.remaining()
242+
} else {
243+
BUFFER_SIZE
244+
};
245+
246+
self.buf.reserve(amt);
247+
}
248+
237249
self.buf.put(data);
238250
} else {
239251
// FIXME: improve buf usage.

tonic/src/codec/encode.rs

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,8 @@ use std::pin::Pin;
99
use std::task::{Context, Poll};
1010
use tokio_codec::Encoder;
1111

12+
const BUFFER_SIZE: usize = 8 * 1024;
13+
1214
pub(crate) fn encode_server<T, U>(
1315
encoder: T,
1416
source: U,
@@ -39,7 +41,7 @@ where
3941
U: Stream<Item = Result<T::Item, Status>>,
4042
{
4143
async_stream::stream! {
42-
let mut buf = BytesMut::with_capacity(1024 * 1024);
44+
let mut buf = BytesMut::with_capacity(BUFFER_SIZE);
4345
futures_util::pin_mut!(source);
4446

4547
loop {

tonic/src/codec/mod.rs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,9 @@ mod encode;
88
#[cfg(feature = "prost")]
99
mod prost;
1010

11+
#[cfg(test)]
12+
mod tests;
13+
1114
pub use self::decode::Streaming;
1215
pub(crate) use self::encode::{encode_client, encode_server};
1316
#[cfg(feature = "prost")]

tonic/src/codec/prost.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ where
4040
}
4141

4242
/// A [`Encoder`] that knows how to encode `T`.
43-
#[derive(Debug, Clone)]
43+
#[derive(Debug, Clone, Default)]
4444
pub struct ProstEncoder<T>(PhantomData<T>);
4545

4646
impl<T: Message> Encoder for ProstEncoder<T> {
@@ -60,7 +60,7 @@ impl<T: Message> Encoder for ProstEncoder<T> {
6060
}
6161

6262
/// A [`Decoder`] that knows how to decode `U`.
63-
#[derive(Debug, Clone)]
63+
#[derive(Debug, Clone, Default)]
6464
pub struct ProstDecoder<U>(PhantomData<U>);
6565

6666
impl<U: Message + Default> Decoder for ProstDecoder<U> {

tonic/src/codec/tests.rs

Lines changed: 112 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,112 @@
1+
use super::{
2+
encode_server,
3+
prost::{ProstDecoder, ProstEncoder},
4+
Streaming,
5+
};
6+
use crate::Status;
7+
use bytes::{Buf, BufMut, Bytes, BytesMut, IntoBuf};
8+
use http_body::Body;
9+
use prost::Message;
10+
use std::{
11+
io::Cursor,
12+
pin::Pin,
13+
task::{Context, Poll},
14+
};
15+
16+
#[derive(Clone, PartialEq, prost::Message)]
17+
struct Msg {
18+
#[prost(bytes, tag = "1")]
19+
data: Vec<u8>,
20+
}
21+
22+
#[tokio::test]
23+
async fn decode() {
24+
let decoder = ProstDecoder::<Msg>::default();
25+
26+
let data = Vec::from(&[0u8; 1024][..]);
27+
let msg = Msg { data };
28+
29+
let mut buf = BytesMut::new();
30+
let len = msg.encoded_len();
31+
32+
buf.reserve(len + 5);
33+
buf.put_u8(0);
34+
buf.put_u32_be(len as u32);
35+
msg.encode(&mut buf).unwrap();
36+
37+
let body = MockBody(buf.freeze(), 0, 100);
38+
39+
let mut stream = Streaming::new_request(decoder, body);
40+
41+
while let Some(_) = stream.message().await.unwrap() {}
42+
}
43+
44+
#[tokio::test]
45+
async fn encode() {
46+
let encoder = ProstEncoder::<Msg>::default();
47+
48+
let data = Vec::from(&[0u8; 1024][..]);
49+
let msg = Msg { data };
50+
51+
let messages = std::iter::repeat(Ok::<_, Status>(msg)).take(10000);
52+
let source = futures_util::stream::iter(messages);
53+
54+
let body = encode_server(encoder, source);
55+
56+
futures_util::pin_mut!(body);
57+
58+
while let Some(r) = body.next().await {
59+
r.unwrap();
60+
}
61+
}
62+
63+
#[derive(Debug)]
64+
struct MockBody(Bytes, usize, usize);
65+
66+
impl Body for MockBody {
67+
type Data = Data;
68+
type Error = Status;
69+
70+
fn poll_data(
71+
mut self: Pin<&mut Self>,
72+
_cx: &mut Context<'_>,
73+
) -> Poll<Option<Result<Self::Data, Self::Error>>> {
74+
if self.1 > self.2 {
75+
self.1 += 1;
76+
let data = Data(self.0.clone().into_buf());
77+
Poll::Ready(Some(Ok(data)))
78+
} else {
79+
Poll::Ready(None)
80+
}
81+
}
82+
83+
fn poll_trailers(
84+
self: Pin<&mut Self>,
85+
cx: &mut Context<'_>,
86+
) -> Poll<Result<Option<http::HeaderMap>, Self::Error>> {
87+
drop(cx);
88+
Poll::Ready(Ok(None))
89+
}
90+
}
91+
92+
struct Data(Cursor<Bytes>);
93+
94+
impl Into<Bytes> for Data {
95+
fn into(self) -> Bytes {
96+
self.0.into_inner()
97+
}
98+
}
99+
100+
impl Buf for Data {
101+
fn remaining(&self) -> usize {
102+
self.0.remaining()
103+
}
104+
105+
fn bytes(&self) -> &[u8] {
106+
self.0.bytes()
107+
}
108+
109+
fn advance(&mut self, cnt: usize) {
110+
self.0.advance(cnt)
111+
}
112+
}

0 commit comments

Comments
 (0)