Skip to content

Commit bf9102c

Browse files
authored
gzhttp: Add zstd to transport (#400)
``` goos: windows goarch: amd64 pkg: github.com/klauspost/compress/gzhttp cpu: AMD Ryzen 9 3950X 16-Core Processor BenchmarkTransport BenchmarkTransport/gzhttp-32 2179 503006 ns/op 259.61 MB/s 25.67 pct 9623 B/op 74 allocs/op BenchmarkTransport/stdlib-32 2001 596271 ns/op 219.00 MB/s 25.67 pct 52275 B/op 92 allocs/op BenchmarkTransport/zstd-32 3404 343757 ns/op 379.87 MB/s 24.44 pct 5358 B/op 69 allocs/op BenchmarkTransport/gzhttp-par-32 47127 25402 ns/op 5140.75 MB/s 25.67 pct 9598 B/op 72 allocs/op BenchmarkTransport/stdlib-par-32 39920 30834 ns/op 4235.03 MB/s 25.67 pct 52269 B/op 90 allocs/op BenchmarkTransport/zstd-par-32 68941 17277 ns/op 7558.02 MB/s 24.44 pct 5436 B/op 67 allocs/op PASS Process finished with the exit code 0 ``` * [x] Tests added.
1 parent 2982376 commit bf9102c

File tree

4 files changed

+318
-18
lines changed

4 files changed

+318
-18
lines changed

gzhttp/gzip_test.go

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1133,6 +1133,9 @@ func newTestHandler(body []byte) http.Handler {
11331133
case "/gzipped":
11341134
w.Header().Set("Content-Encoding", "gzip")
11351135
w.Write(body)
1136+
case "/zstd":
1137+
w.Header().Set("Content-Encoding", "zstd")
1138+
w.Write(body)
11361139
default:
11371140
w.Write(body)
11381141
}

gzhttp/transport.go

Lines changed: 105 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -7,24 +7,58 @@ package gzhttp
77
import (
88
"io"
99
"net/http"
10+
"strings"
1011
"sync"
1112

1213
"github.com/klauspost/compress/gzip"
14+
"github.com/klauspost/compress/zstd"
1315
)
1416

15-
// Transport will wrap a transport with a custom gzip handler
17+
// Transport will wrap a transport with a custom handler
1618
// that will request gzip and automatically decompress it.
1719
// Using this is significantly faster than using the default transport.
18-
func Transport(parent http.RoundTripper) http.RoundTripper {
19-
return gzRoundtripper{parent: parent}
20+
func Transport(parent http.RoundTripper, opts ...transportOption) http.RoundTripper {
21+
g := gzRoundtripper{parent: parent, withZstd: true, withGzip: true}
22+
for _, o := range opts {
23+
o(&g)
24+
}
25+
var ae []string
26+
if g.withZstd {
27+
ae = append(ae, "zstd")
28+
}
29+
if g.withGzip {
30+
ae = append(ae, "gzip")
31+
}
32+
g.acceptEncoding = strings.Join(ae, ",")
33+
return &g
34+
}
35+
36+
type transportOption func(c *gzRoundtripper)
37+
38+
// TransportEnableZstd will send Zstandard as a compression option to the server.
39+
// Enabled by default, but may be disabled if future problems arise.
40+
func TransportEnableZstd(b bool) transportOption {
41+
return func(c *gzRoundtripper) {
42+
c.withZstd = b
43+
}
44+
}
45+
46+
// TransportEnableGzip will send Gzip as a compression option to the server.
47+
// Enabled by default.
48+
func TransportEnableGzip(b bool) transportOption {
49+
return func(c *gzRoundtripper) {
50+
c.withGzip = b
51+
}
2052
}
2153

2254
type gzRoundtripper struct {
23-
parent http.RoundTripper
55+
parent http.RoundTripper
56+
acceptEncoding string
57+
withZstd, withGzip bool
2458
}
2559

26-
func (g gzRoundtripper) RoundTrip(req *http.Request) (*http.Response, error) {
27-
var requestedGzip bool
60+
func (g *gzRoundtripper) RoundTrip(req *http.Request) (*http.Response, error) {
61+
var requestedComp bool
2862
if req.Header.Get("Accept-Encoding") == "" &&
2963
req.Header.Get("Range") == "" &&
3064
req.Method != "HEAD" {
@@ -40,20 +74,31 @@ func (g gzRoundtripper) RoundTrip(req *http.Request) (*http.Response, error) {
4074
// We don't request gzip if the request is for a range, since
4175
// auto-decoding a portion of a gzipped document will just fail
4276
// anyway. See https://golang.org/issue/8923
43-
requestedGzip = true
44-
req.Header.Set("Accept-Encoding", "gzip")
77+
requestedComp = len(g.acceptEncoding) > 0
78+
req.Header.Set("Accept-Encoding", g.acceptEncoding)
4579
}
80+
4681
resp, err := g.parent.RoundTrip(req)
47-
if err != nil || !requestedGzip {
82+
if err != nil || !requestedComp {
4883
return resp, err
4984
}
50-
if asciiEqualFold(resp.Header.Get("Content-Encoding"), "gzip") {
85+
86+
// Decompress
87+
if g.withGzip && asciiEqualFold(resp.Header.Get("Content-Encoding"), "gzip") {
5188
resp.Body = &gzipReader{body: resp.Body}
5289
resp.Header.Del("Content-Encoding")
5390
resp.Header.Del("Content-Length")
5491
resp.ContentLength = -1
5592
resp.Uncompressed = true
5693
}
94+
if g.withZstd && asciiEqualFold(resp.Header.Get("Content-Encoding"), "zstd") {
95+
resp.Body = &zstdReader{body: resp.Body}
96+
resp.Header.Del("Content-Encoding")
97+
resp.Header.Del("Content-Length")
98+
resp.ContentLength = -1
99+
resp.Uncompressed = true
100+
}
101+
57102
return resp, nil
58103
}
59104

@@ -114,3 +159,53 @@ func lower(b byte) byte {
114159
}
115160
return b
116161
}
162+
163+
// zstdReaderPool pools zstd decoders.
164+
var zstdReaderPool sync.Pool
165+
166+
// zstdReader wraps a response body so it can lazily
167+
// call gzip.NewReader on the first call to Read
168+
type zstdReader struct {
169+
body io.ReadCloser // underlying HTTP/1 response body framing
170+
zr *zstd.Decoder // lazily-initialized gzip reader
171+
zerr error // any error from zstd.NewReader; sticky
172+
}
173+
174+
func (zr *zstdReader) Read(p []byte) (n int, err error) {
175+
if zr.zerr != nil {
176+
return 0, zr.zerr
177+
}
178+
if zr.zr == nil {
179+
if zr.zerr == nil {
180+
reader, ok := zstdReaderPool.Get().(*zstd.Decoder)
181+
if ok {
182+
zr.zerr = reader.Reset(zr.body)
183+
zr.zr = reader
184+
} else {
185+
zr.zr, zr.zerr = zstd.NewReader(zr.body, zstd.WithDecoderLowmem(true), zstd.WithDecoderMaxWindow(32<<20), zstd.WithDecoderConcurrency(1))
186+
}
187+
}
188+
if zr.zerr != nil {
189+
return 0, zr.zerr
190+
}
191+
}
192+
n, err = zr.zr.Read(p)
193+
if err != nil {
194+
// Usually this will be io.EOF,
195+
// stash the decoder and keep the error.
196+
zr.zr.Reset(nil)
197+
zstdReaderPool.Put(zr.zr)
198+
zr.zr = nil
199+
zr.zerr = err
200+
}
201+
return
202+
}
203+
204+
func (zr *zstdReader) Close() error {
205+
if zr.zr != nil {
206+
zr.zr.Reset(nil)
207+
zstdReaderPool.Put(zr.zr)
208+
zr.zr = nil
209+
}
210+
return zr.body.Close()
211+
}

0 commit comments

Comments
 (0)