diff --git a/tower-http/CHANGELOG.md b/tower-http/CHANGELOG.md
index 86728227..17757e18 100644
--- a/tower-http/CHANGELOG.md
+++ b/tower-http/CHANGELOG.md
@@ -25,10 +25,12 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
## Fixed
- Don't include identity in Content-Encoding header ([#317])
+- **compression:** Do compress SVGs ([#321])
[#290]: https://github.com/tower-rs/tower-http/pull/290
[#283]: https://github.com/tower-rs/tower-http/pull/283
[#317]: https://github.com/tower-rs/tower-http/pull/317
+[#321]: https://github.com/tower-rs/tower-http/pull/321
# 0.3.5 (December 02, 2022)
diff --git a/tower-http/src/compression/mod.rs b/tower-http/src/compression/mod.rs
index 14da1c1d..7f7c143f 100644
--- a/tower-http/src/compression/mod.rs
+++ b/tower-http/src/compression/mod.rs
@@ -83,10 +83,13 @@ pub use self::{
#[cfg(test)]
mod tests {
+ use crate::compression::predicate::SizeAbove;
+
use super::*;
use async_compression::tokio::write::{BrotliDecoder, BrotliEncoder};
use bytes::BytesMut;
use flate2::read::GzDecoder;
+ use http::header::{ACCEPT_ENCODING, CONTENT_ENCODING, CONTENT_TYPE};
use http_body::Body as _;
use hyper::{Body, Error, Request, Response, Server};
use std::sync::{Arc, RwLock};
@@ -281,4 +284,54 @@ mod tests {
}
assert!(String::from_utf8(data.to_vec()).is_err());
}
+
+ #[tokio::test]
+ async fn doesnt_compress_images() {
+ async fn handle(_req: Request
) -> Result, Error> {
+ let mut res = Response::new(Body::from(
+ "a".repeat((SizeAbove::DEFAULT_MIN_SIZE * 2) as usize),
+ ));
+ res.headers_mut()
+ .insert(CONTENT_TYPE, "image/png".parse().unwrap());
+ Ok(res)
+ }
+
+ let svc = Compression::new(service_fn(handle));
+
+ let res = svc
+ .oneshot(
+ Request::builder()
+ .header(ACCEPT_ENCODING, "gzip")
+ .body(Body::empty())
+ .unwrap(),
+ )
+ .await
+ .unwrap();
+ assert!(res.headers().get(CONTENT_ENCODING).is_none());
+ }
+
+ #[tokio::test]
+ async fn does_compress_svg() {
+ async fn handle(_req: Request) -> Result, Error> {
+ let mut res = Response::new(Body::from(
+ "a".repeat((SizeAbove::DEFAULT_MIN_SIZE * 2) as usize),
+ ));
+ res.headers_mut()
+ .insert(CONTENT_TYPE, "image/svg+xml".parse().unwrap());
+ Ok(res)
+ }
+
+ let svc = Compression::new(service_fn(handle));
+
+ let res = svc
+ .oneshot(
+ Request::builder()
+ .header(ACCEPT_ENCODING, "gzip")
+ .body(Body::empty())
+ .unwrap(),
+ )
+ .await
+ .unwrap();
+ assert_eq!(res.headers()[CONTENT_ENCODING], "gzip");
+ }
}
diff --git a/tower-http/src/compression/predicate.rs b/tower-http/src/compression/predicate.rs
index 2a1d7b45..2bb37c22 100644
--- a/tower-http/src/compression/predicate.rs
+++ b/tower-http/src/compression/predicate.rs
@@ -145,7 +145,7 @@ impl Predicate for DefaultPredicate {
pub struct SizeAbove(u16);
impl SizeAbove {
- const DEFAULT_MIN_SIZE: u16 = 32;
+ pub(crate) const DEFAULT_MIN_SIZE: u16 = 32;
/// Create a new `SizeAbove` predicate that will only compress responses larger than
/// `min_size_bytes`.
@@ -185,23 +185,35 @@ impl Predicate for SizeAbove {
/// Predicate that wont allow responses with a specific `content-type` to be compressed.
#[derive(Clone, Debug)]
-pub struct NotForContentType(Str);
+pub struct NotForContentType {
+ content_type: Str,
+ exception: Option,
+}
impl NotForContentType {
/// Predicate that wont compress gRPC responses.
pub const GRPC: Self = Self::const_new("application/grpc");
/// Predicate that wont compress images.
- pub const IMAGES: Self = Self::const_new("image/");
+ pub const IMAGES: Self = Self {
+ content_type: Str::Static("image/"),
+ exception: Some(Str::Static("image/svg+xml")),
+ };
/// Create a new `NotForContentType`.
pub fn new(content_type: &str) -> Self {
- Self(Str::Shared(content_type.into()))
+ Self {
+ content_type: Str::Shared(content_type.into()),
+ exception: None,
+ }
}
/// Create a new `NotForContentType` from a static string.
pub const fn const_new(content_type: &'static str) -> Self {
- Self(Str::Static(content_type))
+ Self {
+ content_type: Str::Static(content_type),
+ exception: None,
+ }
}
}
@@ -210,11 +222,13 @@ impl Predicate for NotForContentType {
where
B: Body,
{
- let str = match &self.0 {
- Str::Static(str) => *str,
- Str::Shared(arc) => &*arc,
- };
- !content_type(response).starts_with(str)
+ if let Some(except) = &self.exception {
+ if content_type(response) == except.as_str() {
+ return true;
+ }
+ }
+
+ !content_type(response).starts_with(self.content_type.as_str())
}
}
@@ -224,6 +238,15 @@ enum Str {
Shared(Arc),
}
+impl Str {
+ fn as_str(&self) -> &str {
+ match self {
+ Str::Static(s) => s,
+ Str::Shared(s) => s,
+ }
+ }
+}
+
impl fmt::Debug for Str {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {