@@ -7,24 +7,58 @@ package gzhttp
77import (
88 "io"
99 "net/http"
10+ "strings"
1011 "sync"
1112
1213 "github.com/klauspost/compress/gzip"
14+ "github.com/klauspost/compress/zstd"
1315)
1416
15- // Transport will wrap a transport with a custom gzip handler
17+ // Transport will wrap a transport with a custom handler
1618// that will request gzip and automatically decompress it.
1719// Using this is significantly faster than using the default transport.
18- func Transport (parent http.RoundTripper ) http.RoundTripper {
19- return gzRoundtripper {parent : parent }
20+ func Transport (parent http.RoundTripper , opts ... transportOption ) http.RoundTripper {
21+ g := gzRoundtripper {parent : parent , withZstd : true , withGzip : true }
22+ for _ , o := range opts {
23+ o (& g )
24+ }
25+ var ae []string
26+ if g .withZstd {
27+ ae = append (ae , "zstd" )
28+ }
29+ if g .withGzip {
30+ ae = append (ae , "gzip" )
31+ }
32+ g .acceptEncoding = strings .Join (ae , "," )
33+ return & g
34+ }
35+
36+ type transportOption func (c * gzRoundtripper )
37+
38+ // TransportEnableZstd will send Zstandard as a compression option to the server.
39+ // Enabled by default, but may be disabled if future problems arise.
40+ func TransportEnableZstd (b bool ) transportOption {
41+ return func (c * gzRoundtripper ) {
42+ c .withZstd = b
43+ }
44+ }
45+
46+ // TransportEnableGzip will send Gzip as a compression option to the server.
47+ // Enabled by default.
48+ func TransportEnableGzip (b bool ) transportOption {
49+ return func (c * gzRoundtripper ) {
50+ c .withGzip = b
51+ }
2052}
2153
2254type gzRoundtripper struct {
23- parent http.RoundTripper
55+ parent http.RoundTripper
56+ acceptEncoding string
57+ withZstd , withGzip bool
2458}
2559
26- func (g gzRoundtripper ) RoundTrip (req * http.Request ) (* http.Response , error ) {
27- var requestedGzip bool
60+ func (g * gzRoundtripper ) RoundTrip (req * http.Request ) (* http.Response , error ) {
61+ var requestedComp bool
2862 if req .Header .Get ("Accept-Encoding" ) == "" &&
2963 req .Header .Get ("Range" ) == "" &&
3064 req .Method != "HEAD" {
@@ -40,20 +74,31 @@ func (g gzRoundtripper) RoundTrip(req *http.Request) (*http.Response, error) {
4074 // We don't request gzip if the request is for a range, since
4175 // auto-decoding a portion of a gzipped document will just fail
4276 // anyway. See https://golang.org/issue/8923
43- requestedGzip = true
44- req .Header .Set ("Accept-Encoding" , "gzip" )
77+ requestedComp = len ( g . acceptEncoding ) > 0
78+ req .Header .Set ("Accept-Encoding" , g . acceptEncoding )
4579 }
80+
4681 resp , err := g .parent .RoundTrip (req )
47- if err != nil || ! requestedGzip {
82+ if err != nil || ! requestedComp {
4883 return resp , err
4984 }
50- if asciiEqualFold (resp .Header .Get ("Content-Encoding" ), "gzip" ) {
85+
86+ // Decompress
87+ if g .withGzip && asciiEqualFold (resp .Header .Get ("Content-Encoding" ), "gzip" ) {
5188 resp .Body = & gzipReader {body : resp .Body }
5289 resp .Header .Del ("Content-Encoding" )
5390 resp .Header .Del ("Content-Length" )
5491 resp .ContentLength = - 1
5592 resp .Uncompressed = true
5693 }
94+ if g .withZstd && asciiEqualFold (resp .Header .Get ("Content-Encoding" ), "zstd" ) {
95+ resp .Body = & zstdReader {body : resp .Body }
96+ resp .Header .Del ("Content-Encoding" )
97+ resp .Header .Del ("Content-Length" )
98+ resp .ContentLength = - 1
99+ resp .Uncompressed = true
100+ }
101+
57102 return resp , nil
58103}
59104
@@ -114,3 +159,53 @@ func lower(b byte) byte {
114159 }
115160 return b
116161}
162+
163+ // zstdReaderPool pools zstd decoders.
164+ var zstdReaderPool sync.Pool
165+
166+ // zstdReader wraps a response body so it can lazily
167+ // call gzip.NewReader on the first call to Read
168+ type zstdReader struct {
169+ body io.ReadCloser // underlying HTTP/1 response body framing
170+ zr * zstd.Decoder // lazily-initialized gzip reader
171+ zerr error // any error from zstd.NewReader; sticky
172+ }
173+
174+ func (zr * zstdReader ) Read (p []byte ) (n int , err error ) {
175+ if zr .zerr != nil {
176+ return 0 , zr .zerr
177+ }
178+ if zr .zr == nil {
179+ if zr .zerr == nil {
180+ reader , ok := zstdReaderPool .Get ().(* zstd.Decoder )
181+ if ok {
182+ zr .zerr = reader .Reset (zr .body )
183+ zr .zr = reader
184+ } else {
185+ zr .zr , zr .zerr = zstd .NewReader (zr .body , zstd .WithDecoderLowmem (true ), zstd .WithDecoderMaxWindow (32 << 20 ), zstd .WithDecoderConcurrency (1 ))
186+ }
187+ }
188+ if zr .zerr != nil {
189+ return 0 , zr .zerr
190+ }
191+ }
192+ n , err = zr .zr .Read (p )
193+ if err != nil {
194+ // Usually this will be io.EOF,
195+ // stash the decoder and keep the error.
196+ zr .zr .Reset (nil )
197+ zstdReaderPool .Put (zr .zr )
198+ zr .zr = nil
199+ zr .zerr = err
200+ }
201+ return
202+ }
203+
204+ func (zr * zstdReader ) Close () error {
205+ if zr .zr != nil {
206+ zr .zr .Reset (nil )
207+ zstdReaderPool .Put (zr .zr )
208+ zr .zr = nil
209+ }
210+ return zr .body .Close ()
211+ }
0 commit comments