Skip to content

Commit

Permalink
fix(codec): Fix buffer decode panic on full (#43)
Browse files Browse the repository at this point in the history
* fix(codec): Fix buffer decode panic on full

This is a naive fix for the buffer growing beyond capacity
and producing a panic. Ideally we should do a better job
of not having to allocate for new messages by using
a link list.

* fmt
  • Loading branch information
LucioFranco authored Oct 4, 2019
1 parent bd2b4e0 commit ed3e7e9
Show file tree
Hide file tree
Showing 5 changed files with 134 additions and 5 deletions.
16 changes: 14 additions & 2 deletions tonic/src/codec/decode.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@ use std::{
};
use tracing::{debug, trace};

const BUFFER_SIZE: usize = 8 * 1024;

/// Streaming requests and responses.
///
/// This will wrap some inner [`Body`] and [`Decoder`] and provide an interface
Expand Down Expand Up @@ -70,6 +72,7 @@ impl<T> Streaming<T> {
{
Self::new(decoder, body, Direction::Request)
}

fn new<B, D>(decoder: D, body: B, direction: Direction) -> Self
where
B: Body + Send + 'static,
Expand All @@ -82,8 +85,7 @@ impl<T> Streaming<T> {
body: BoxBody::map_from(body),
state: State::ReadHeader,
direction,
// FIXME: update this with a reasonable size
buf: BytesMut::with_capacity(1024 * 1024),
buf: BytesMut::with_capacity(BUFFER_SIZE),
trailers: None,
}
}
Expand Down Expand Up @@ -234,6 +236,16 @@ impl<T> Stream for Streaming<T> {
};

if let Some(data) = chunk {
if data.remaining() > self.buf.remaining_mut() {
let amt = if data.remaining() > BUFFER_SIZE {
data.remaining()
} else {
BUFFER_SIZE
};

self.buf.reserve(amt);
}

self.buf.put(data);
} else {
// FIXME: improve buf usage.
Expand Down
4 changes: 3 additions & 1 deletion tonic/src/codec/encode.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@ use std::pin::Pin;
use std::task::{Context, Poll};
use tokio_codec::Encoder;

const BUFFER_SIZE: usize = 8 * 1024;

pub(crate) fn encode_server<T, U>(
encoder: T,
source: U,
Expand Down Expand Up @@ -39,7 +41,7 @@ where
U: Stream<Item = Result<T::Item, Status>>,
{
async_stream::stream! {
let mut buf = BytesMut::with_capacity(1024 * 1024);
let mut buf = BytesMut::with_capacity(BUFFER_SIZE);
futures_util::pin_mut!(source);

loop {
Expand Down
3 changes: 3 additions & 0 deletions tonic/src/codec/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,9 @@ mod encode;
#[cfg(feature = "prost")]
mod prost;

#[cfg(test)]
mod tests;

pub use self::decode::Streaming;
pub(crate) use self::encode::{encode_client, encode_server};
#[cfg(feature = "prost")]
Expand Down
4 changes: 2 additions & 2 deletions tonic/src/codec/prost.rs
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ where
}

/// A [`Encoder`] that knows how to encode `T`.
#[derive(Debug, Clone)]
#[derive(Debug, Clone, Default)]
pub struct ProstEncoder<T>(PhantomData<T>);

impl<T: Message> Encoder for ProstEncoder<T> {
Expand All @@ -60,7 +60,7 @@ impl<T: Message> Encoder for ProstEncoder<T> {
}

/// A [`Decoder`] that knows how to decode `U`.
#[derive(Debug, Clone)]
#[derive(Debug, Clone, Default)]
pub struct ProstDecoder<U>(PhantomData<U>);

impl<U: Message + Default> Decoder for ProstDecoder<U> {
Expand Down
112 changes: 112 additions & 0 deletions tonic/src/codec/tests.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,112 @@
use super::{
encode_server,
prost::{ProstDecoder, ProstEncoder},
Streaming,
};
use crate::Status;
use bytes::{Buf, BufMut, Bytes, BytesMut, IntoBuf};
use http_body::Body;
use prost::Message;
use std::{
io::Cursor,
pin::Pin,
task::{Context, Poll},
};

#[derive(Clone, PartialEq, prost::Message)]
struct Msg {
#[prost(bytes, tag = "1")]
data: Vec<u8>,
}

#[tokio::test]
async fn decode() {
let decoder = ProstDecoder::<Msg>::default();

let data = Vec::from(&[0u8; 1024][..]);
let msg = Msg { data };

let mut buf = BytesMut::new();
let len = msg.encoded_len();

buf.reserve(len + 5);
buf.put_u8(0);
buf.put_u32_be(len as u32);
msg.encode(&mut buf).unwrap();

let body = MockBody(buf.freeze(), 0, 100);

let mut stream = Streaming::new_request(decoder, body);

while let Some(_) = stream.message().await.unwrap() {}
}

#[tokio::test]
async fn encode() {
let encoder = ProstEncoder::<Msg>::default();

let data = Vec::from(&[0u8; 1024][..]);
let msg = Msg { data };

let messages = std::iter::repeat(Ok::<_, Status>(msg)).take(10000);
let source = futures_util::stream::iter(messages);

let body = encode_server(encoder, source);

futures_util::pin_mut!(body);

while let Some(r) = body.next().await {
r.unwrap();
}
}

#[derive(Debug)]
struct MockBody(Bytes, usize, usize);

impl Body for MockBody {
type Data = Data;
type Error = Status;

fn poll_data(
mut self: Pin<&mut Self>,
_cx: &mut Context<'_>,
) -> Poll<Option<Result<Self::Data, Self::Error>>> {
if self.1 > self.2 {
self.1 += 1;
let data = Data(self.0.clone().into_buf());
Poll::Ready(Some(Ok(data)))
} else {
Poll::Ready(None)
}
}

fn poll_trailers(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
) -> Poll<Result<Option<http::HeaderMap>, Self::Error>> {
drop(cx);
Poll::Ready(Ok(None))
}
}

struct Data(Cursor<Bytes>);

impl Into<Bytes> for Data {
fn into(self) -> Bytes {
self.0.into_inner()
}
}

impl Buf for Data {
fn remaining(&self) -> usize {
self.0.remaining()
}

fn bytes(&self) -> &[u8] {
self.0.bytes()
}

fn advance(&mut self, cnt: usize) {
self.0.advance(cnt)
}
}

0 comments on commit ed3e7e9

Please sign in to comment.