Skip to content

Commit

Permalink
feat(core): compression (viz-rs#52)
Browse files Browse the repository at this point in the history
  • Loading branch information
fundon authored Sep 27, 2022
1 parent 71c64b0 commit fa25f45
Show file tree
Hide file tree
Showing 8 changed files with 209 additions and 1 deletion.
1 change: 1 addition & 0 deletions examples/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ members = [
"static-routes",
"routing/todos",
"otel/*",
"compression"
]
exclude = ["tls", "target"]

Expand Down
1 change: 1 addition & 0 deletions examples/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ Here you can find a lot of small crabs 🦀.
* [Session](session)
* [CSRF](csrf)
* [CORS](cors)
* [Compression response body](compresssion)
* [HTTPS/TLS - rustls](rustls)
* [Defined a static router](static-routes)
* [Todos](routing/todos)
Expand Down
10 changes: 10 additions & 0 deletions examples/compression/Cargo.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
[package]
name = "compression"
version = "0.1.0"
edition = "2021"
publish = false

[dependencies]
viz = { path = "../../viz", features = ["compression"] }

tokio = { version = "1.0", features = ["rt-multi-thread", "macros"] }
25 changes: 25 additions & 0 deletions examples/compression/src/main.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
#![deny(warnings)]

use std::net::SocketAddr;

use viz::{get, middleware::compression, Request, Result, Router, Server, ServiceMaker};

async fn index(_req: Request) -> Result<&'static str> {
Ok("Hello, World!")
}

#[tokio::main]
async fn main() -> Result<()> {
let addr = SocketAddr::from(([127, 0, 0, 1], 3000));
println!("listening on {}", addr);

let app = Router::new()
.route("/", get(index))
.with(compression::Config::default());

if let Err(err) = Server::bind(&addr).serve(ServiceMaker::from(app)).await {
println!("{}", err);
}

Ok(())
}
7 changes: 6 additions & 1 deletion viz-core/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ default = [
"multipart",
"websocket",
"cookie",
"session"
"session",
]

state = []
Expand All @@ -46,6 +46,8 @@ session = ["cookie", "json", "dep:sessions-core"]
csrf = ["cookie", "dep:base64", "dep:getrandom"]
cors = []

compression = ["dep:tokio-util", "dep:async-compression"]

otel = ["dep:opentelemetry", "dep:opentelemetry-semantic-conventions"]
otel-tracing = ["otel", "opentelemetry?/trace"]
otel-metrics = ["otel", "opentelemetry?/metrics"]
Expand Down Expand Up @@ -74,6 +76,9 @@ sessions-core = { version = "0.3.4", optional = true }
getrandom = { version = "0.2.7", optional = true }
base64 = { version = "0.13.0", optional = true }

# Compression
async-compression = { version = "0.3.14", features = ["tokio", "gzip", "brotli", "deflate"], optional = true }

# OpenTelemetry
opentelemetry = { version = "0.18.0", default-features = false, optional = true }
opentelemetry-semantic-conventions = { version = "0.10.0", optional = true }
Expand Down
161 changes: 161 additions & 0 deletions viz-core/src/middleware/compression.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,161 @@
//! Compression Middleware.
use std::{io, str::FromStr};

use async_compression::tokio::bufread;
use futures_util::TryStreamExt;
use tokio_util::io::{ReaderStream, StreamReader};

use crate::{
async_trait,
header::{HeaderValue, ACCEPT_ENCODING, CONTENT_ENCODING, CONTENT_LENGTH},
Body, Handler, IntoResponse, Request, Response, Result, Transform,
};

/// Compress response body.
#[derive(Debug, Default)]
pub struct Config;

impl<H> Transform<H> for Config
where
H: Clone,
{
type Output = CompressionMiddleware<H>;

fn transform(&self, h: H) -> Self::Output {
CompressionMiddleware { h }
}
}

/// Compression middleware.
#[derive(Clone, Debug)]
pub struct CompressionMiddleware<H> {
h: H,
}

#[async_trait]
impl<H, O> Handler<Request> for CompressionMiddleware<H>
where
O: IntoResponse,
H: Handler<Request, Output = Result<O>> + Clone,
{
type Output = Result<Response>;

async fn call(&self, req: Request) -> Self::Output {
let accept_encoding = req
.headers()
.get(ACCEPT_ENCODING)
.and_then(|v| v.to_str().ok())
.and_then(parse_accept_encoding);

let raw = self.h.call(req).await?;

Ok(match accept_encoding {
Some(algo) => Compress::new(raw, algo).into_response(),
None => raw.into_response(),
})
}
}

/// Compresses the response body with the specified algorithm
/// and sets the `Content-Encoding` header.
#[derive(Debug)]
pub struct Compress<T> {
inner: T,
algo: ContentCoding,
}

impl<T> Compress<T> {
/// Creates a compressed response with the specified algorithm.
pub fn new(inner: T, algo: ContentCoding) -> Self {
Self { inner, algo }
}
}

impl<T: IntoResponse> IntoResponse for Compress<T> {
fn into_response(self) -> Response {
let mut res = self.inner.into_response();

match self.algo {
ContentCoding::Gzip | ContentCoding::Deflate | ContentCoding::Brotli => {
res = res.map(|body| {
let body = StreamReader::new(body.map_err(map_hyper_err));
if self.algo == ContentCoding::Gzip {
Body::wrap_stream(ReaderStream::new(bufread::GzipEncoder::new(body)))
} else if self.algo == ContentCoding::Deflate {
Body::wrap_stream(ReaderStream::new(bufread::DeflateEncoder::new(body)))
} else {
Body::wrap_stream(ReaderStream::new(bufread::BrotliEncoder::new(body)))
}
});
res.headers_mut()
.append(CONTENT_ENCODING, HeaderValue::from_static(self.algo.into()));
res.headers_mut().remove(CONTENT_LENGTH);
res
}
ContentCoding::Any => res,
}
}
}

/// [ContentCoding]
///
/// [ContentCoding]: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Accept-Encoding
#[derive(Debug, PartialEq)]
pub enum ContentCoding {
/// gzip
Gzip,
/// deflate
Deflate,
/// brotli
Brotli,
/// *
Any,
}

impl FromStr for ContentCoding {
type Err = ();

fn from_str(s: &str) -> Result<Self, Self::Err> {
if s.eq_ignore_ascii_case("deflate") {
Ok(ContentCoding::Deflate)
} else if s.eq_ignore_ascii_case("gzip") {
Ok(ContentCoding::Gzip)
} else if s == "*" {
Ok(ContentCoding::Any)
} else {
Err(())
}
}
}

impl From<ContentCoding> for &'static str {
fn from(cc: ContentCoding) -> Self {
match cc {
ContentCoding::Gzip => "gzip",
ContentCoding::Deflate => "deflate",
ContentCoding::Brotli => "brotli",
ContentCoding::Any => "*",
}
}
}

fn parse_accept_encoding(s: &str) -> Option<ContentCoding> {
s.split(',')
.map(str::trim)
.filter_map(|v| {
Some(match v.split_once(";q=") {
Some((c, q)) => (
c.parse::<ContentCoding>().ok()?,
q.parse::<f32>().ok()? * 1000.,
),
None => (v.parse::<ContentCoding>().ok()?, 1000.),
})
})
.max_by_key(|(_, q)| *q as u16)
.map(|(c, _)| c)
}

fn map_hyper_err(e: hyper::Error) -> io::Error {
io::Error::new(io::ErrorKind::Other, e)
}
3 changes: 3 additions & 0 deletions viz-core/src/middleware/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,3 +16,6 @@ pub mod otel;

#[cfg(feature = "cookie")]
pub mod helper;

#[cfg(feature = "compression")]
pub mod compression;
2 changes: 2 additions & 0 deletions viz/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,8 @@ session = ["cookie", "viz-core/session"]
csrf = ["cookie", "viz-core/csrf"]
cors = ["viz-core/cors"]

compression = ["viz-core/compression"]

unix-socket = []

macros = ["dep:viz-macros"]
Expand Down

0 comments on commit fa25f45

Please sign in to comment.