Skip to content

Commit

Permalink
feat(codec): compression support (#692)
Browse files Browse the repository at this point in the history
* Initial compression support

* Support configuring compression on `Server`

* Minor clean up

* Test that compression is actually happening

* Clean up some todos

* channels compressing requests

* Move compression to be on the codecs

* Test sending compressed request to server that doesn't support it

* Clean up a bit

* Compress server streams

* Compress client streams

* Bidirectional streaming compression

* Handle receiving unsupported encoding

* Clean up

* Add note to future self

* Support disabling compression for individual responses

* Add docs

* Add compression examples

* Disable compression behind feature flag

* Add some docs

* Make flate2 optional dependency

* Fix docs wording

* Format

* Reply with which encodings are supported

* Convert tests to use mocked io

* Fix lints

* Use separate counters

* Don't make a long stream

* Address review feedback
  • Loading branch information
davidpdrsn authored Jul 2, 2021
1 parent 7677ad6 commit 0583cff
Show file tree
Hide file tree
Showing 30 changed files with 2,190 additions and 97 deletions.
1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ members = [
"tests/integration_tests",
"tests/stream_conflict",
"tests/root-crate-path",
"tests/compression",
"tonic-web/tests/integration"
]

8 changes: 8 additions & 0 deletions examples/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,14 @@ path = "src/hyper_warp_multiplex/client.rs"
name = "hyper-warp-multiplex-server"
path = "src/hyper_warp_multiplex/server.rs"

[[bin]]
name = "compression-server"
path = "src/compression/server.rs"

[[bin]]
name = "compression-client"
path = "src/compression/client.rs"

[dependencies]
tonic = { path = "../tonic", features = ["tls"] }
prost = "0.7"
Expand Down
27 changes: 27 additions & 0 deletions examples/src/compression/client.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
use hello_world::greeter_client::GreeterClient;
use hello_world::HelloRequest;
use tonic::transport::Channel;

pub mod hello_world {
tonic::include_proto!("helloworld");
}

#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
let channel = Channel::builder("http://[::1]:50051".parse().unwrap())
.connect()
.await
.unwrap();

let mut client = GreeterClient::new(channel).send_gzip().accept_gzip();

let request = tonic::Request::new(HelloRequest {
name: "Tonic".into(),
});

let response = client.say_hello(request).await?;

dbg!(response);

Ok(())
}
40 changes: 40 additions & 0 deletions examples/src/compression/server.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
use tonic::{transport::Server, Request, Response, Status};

use hello_world::greeter_server::{Greeter, GreeterServer};
use hello_world::{HelloReply, HelloRequest};

pub mod hello_world {
tonic::include_proto!("helloworld");
}

#[derive(Default)]
pub struct MyGreeter {}

#[tonic::async_trait]
impl Greeter for MyGreeter {
async fn say_hello(
&self,
request: Request<HelloRequest>,
) -> Result<Response<HelloReply>, Status> {
println!("Got a request from {:?}", request.remote_addr());

let reply = hello_world::HelloReply {
message: format!("Hello {}!", request.into_inner().name),
};
Ok(Response::new(reply))
}
}

#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
let addr = "[::1]:50051".parse().unwrap();
let greeter = MyGreeter::default();

println!("GreeterServer listening on {}", addr);

let service = GreeterServer::new(greeter).send_gzip().accept_gzip();

Server::builder().add_service(service).serve(addr).await?;

Ok(())
}
24 changes: 24 additions & 0 deletions tests/compression/Cargo.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
[package]
name = "compression"
version = "0.1.0"
authors = ["Lucio Franco <luciofranco14@gmail.com>"]
edition = "2018"
publish = false
license = "MIT"

[dependencies]
tonic = { path = "../../tonic", features = ["compression"] }
prost = "0.7"
tokio = { version = "1.0", features = ["macros", "rt-multi-thread", "net"] }
tower = { version = "0.4", features = [] }
http-body = "0.4"
http = "0.2"
tokio-stream = { version = "0.1.5", features = ["net"] }
tower-http = { version = "0.1", features = ["map-response-body", "map-request-body"] }
bytes = "1"
futures = "0.3"
pin-project = "1.0"
hyper = "0.14"

[build-dependencies]
tonic-build = { path = "../../tonic-build", features = ["compression"] }
3 changes: 3 additions & 0 deletions tests/compression/build.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
fn main() {
tonic_build::compile_protos("proto/test.proto").unwrap();
}
19 changes: 19 additions & 0 deletions tests/compression/proto/test.proto
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
syntax = "proto3";

package test;

import "google/protobuf/empty.proto";

service Test {
rpc CompressOutputUnary(google.protobuf.Empty) returns (SomeData);
rpc CompressInputUnary(SomeData) returns (google.protobuf.Empty);
rpc CompressOutputServerStream(google.protobuf.Empty) returns (stream SomeData);
rpc CompressInputClientStream(stream SomeData) returns (google.protobuf.Empty);
rpc CompressOutputClientStream(stream SomeData) returns (SomeData);
rpc CompressInputOutputBidirectionalStream(stream SomeData) returns (stream SomeData);
}

message SomeData {
// include a bunch of data so there actually is something to compress
bytes data = 1;
}
78 changes: 78 additions & 0 deletions tests/compression/src/bidirectional_stream.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
use super::*;

#[tokio::test(flavor = "multi_thread")]
async fn client_enabled_server_enabled() {
let (client, server) = tokio::io::duplex(UNCOMPRESSED_MIN_BODY_SIZE * 10);

let svc = test_server::TestServer::new(Svc::default())
.accept_gzip()
.send_gzip();

let request_bytes_counter = Arc::new(AtomicUsize::new(0));
let response_bytes_counter = Arc::new(AtomicUsize::new(0));

fn assert_right_encoding<B>(req: http::Request<B>) -> http::Request<B> {
assert_eq!(req.headers().get("grpc-encoding").unwrap(), "gzip");
req
}

tokio::spawn({
let request_bytes_counter = request_bytes_counter.clone();
let response_bytes_counter = response_bytes_counter.clone();
async move {
Server::builder()
.layer(
ServiceBuilder::new()
.map_request(assert_right_encoding)
.layer(measure_request_body_size_layer(
request_bytes_counter.clone(),
))
.layer(MapResponseBodyLayer::new(move |body| {
util::CountBytesBody {
inner: body,
counter: response_bytes_counter.clone(),
}
}))
.into_inner(),
)
.add_service(svc)
.serve_with_incoming(futures::stream::iter(vec![Ok::<_, std::io::Error>(
MockStream(server),
)]))
.await
.unwrap();
}
});

let mut client = test_client::TestClient::new(mock_io_channel(client).await)
.send_gzip()
.accept_gzip();

let data = [0_u8; UNCOMPRESSED_MIN_BODY_SIZE].to_vec();
let stream = futures::stream::iter(vec![SomeData { data: data.clone() }, SomeData { data }]);
let req = Request::new(stream);

let res = client
.compress_input_output_bidirectional_stream(req)
.await
.unwrap();

assert_eq!(res.metadata().get("grpc-encoding").unwrap(), "gzip");

let mut stream: Streaming<SomeData> = res.into_inner();

stream
.next()
.await
.expect("stream empty")
.expect("item was error");

stream
.next()
.await
.expect("stream empty")
.expect("item was error");

assert!(request_bytes_counter.load(SeqCst) < UNCOMPRESSED_MIN_BODY_SIZE);
assert!(response_bytes_counter.load(SeqCst) < UNCOMPRESSED_MIN_BODY_SIZE);
}
167 changes: 167 additions & 0 deletions tests/compression/src/client_stream.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,167 @@
use super::*;
use http_body::Body as _;

#[tokio::test(flavor = "multi_thread")]
async fn client_enabled_server_enabled() {
let (client, server) = tokio::io::duplex(UNCOMPRESSED_MIN_BODY_SIZE * 10);

let svc = test_server::TestServer::new(Svc::default()).accept_gzip();

let request_bytes_counter = Arc::new(AtomicUsize::new(0));

fn assert_right_encoding<B>(req: http::Request<B>) -> http::Request<B> {
assert_eq!(req.headers().get("grpc-encoding").unwrap(), "gzip");
req
}

tokio::spawn({
let request_bytes_counter = request_bytes_counter.clone();
async move {
Server::builder()
.layer(
ServiceBuilder::new()
.map_request(assert_right_encoding)
.layer(measure_request_body_size_layer(
request_bytes_counter.clone(),
))
.into_inner(),
)
.add_service(svc)
.serve_with_incoming(futures::stream::iter(vec![Ok::<_, std::io::Error>(
MockStream(server),
)]))
.await
.unwrap();
}
});

let mut client = test_client::TestClient::new(mock_io_channel(client).await).send_gzip();

let data = [0_u8; UNCOMPRESSED_MIN_BODY_SIZE].to_vec();
let stream = futures::stream::iter(vec![SomeData { data: data.clone() }, SomeData { data }]);
let req = Request::new(Box::pin(stream));

client.compress_input_client_stream(req).await.unwrap();

let bytes_sent = request_bytes_counter.load(SeqCst);
assert!(bytes_sent < UNCOMPRESSED_MIN_BODY_SIZE);
}

#[tokio::test(flavor = "multi_thread")]
async fn client_disabled_server_enabled() {
let (client, server) = tokio::io::duplex(UNCOMPRESSED_MIN_BODY_SIZE * 10);

let svc = test_server::TestServer::new(Svc::default()).accept_gzip();

let request_bytes_counter = Arc::new(AtomicUsize::new(0));

fn assert_right_encoding<B>(req: http::Request<B>) -> http::Request<B> {
assert!(req.headers().get("grpc-encoding").is_none());
req
}

tokio::spawn({
let request_bytes_counter = request_bytes_counter.clone();
async move {
Server::builder()
.layer(
ServiceBuilder::new()
.map_request(assert_right_encoding)
.layer(measure_request_body_size_layer(
request_bytes_counter.clone(),
))
.into_inner(),
)
.add_service(svc)
.serve_with_incoming(futures::stream::iter(vec![Ok::<_, std::io::Error>(
MockStream(server),
)]))
.await
.unwrap();
}
});

let mut client = test_client::TestClient::new(mock_io_channel(client).await);

let data = [0_u8; UNCOMPRESSED_MIN_BODY_SIZE].to_vec();
let stream = futures::stream::iter(vec![SomeData { data: data.clone() }, SomeData { data }]);
let req = Request::new(Box::pin(stream));

client.compress_input_client_stream(req).await.unwrap();

let bytes_sent = request_bytes_counter.load(SeqCst);
assert!(bytes_sent > UNCOMPRESSED_MIN_BODY_SIZE);
}

#[tokio::test(flavor = "multi_thread")]
async fn client_enabled_server_disabled() {
let (client, server) = tokio::io::duplex(UNCOMPRESSED_MIN_BODY_SIZE * 10);

let svc = test_server::TestServer::new(Svc::default());

tokio::spawn(async move {
Server::builder()
.add_service(svc)
.serve_with_incoming(futures::stream::iter(vec![Ok::<_, std::io::Error>(
MockStream(server),
)]))
.await
.unwrap();
});

let mut client = test_client::TestClient::new(mock_io_channel(client).await).send_gzip();

let data = [0_u8; UNCOMPRESSED_MIN_BODY_SIZE].to_vec();
let stream = futures::stream::iter(vec![SomeData { data: data.clone() }, SomeData { data }]);
let req = Request::new(Box::pin(stream));

let status = client.compress_input_client_stream(req).await.unwrap_err();

assert_eq!(status.code(), tonic::Code::Unimplemented);
assert_eq!(
status.message(),
"Content is compressed with `gzip` which isn't supported"
);
}

#[tokio::test(flavor = "multi_thread")]
async fn compressing_response_from_client_stream() {
let (client, server) = tokio::io::duplex(UNCOMPRESSED_MIN_BODY_SIZE * 10);

let svc = test_server::TestServer::new(Svc::default()).send_gzip();

let response_bytes_counter = Arc::new(AtomicUsize::new(0));

tokio::spawn({
let response_bytes_counter = response_bytes_counter.clone();
async move {
Server::builder()
.layer(
ServiceBuilder::new()
.layer(MapResponseBodyLayer::new(move |body| {
util::CountBytesBody {
inner: body,
counter: response_bytes_counter.clone(),
}
}))
.into_inner(),
)
.add_service(svc)
.serve_with_incoming(futures::stream::iter(vec![Ok::<_, std::io::Error>(
MockStream(server),
)]))
.await
.unwrap();
}
});

let mut client = test_client::TestClient::new(mock_io_channel(client).await).accept_gzip();

let stream = futures::stream::iter(vec![]);
let req = Request::new(Box::pin(stream));

let res = client.compress_output_client_stream(req).await.unwrap();
assert_eq!(res.metadata().get("grpc-encoding").unwrap(), "gzip");
let bytes_sent = response_bytes_counter.load(SeqCst);
assert!(bytes_sent < UNCOMPRESSED_MIN_BODY_SIZE);
}
Loading

0 comments on commit 0583cff

Please sign in to comment.