Skip to content

Commit c8a12ef

Browse files
authored
fix: improve gzip request decompression middleware (#2077)
1 parent cc1072b commit c8a12ef

File tree

4 files changed

+101
-18
lines changed

4 files changed

+101
-18
lines changed

router/core/graph_server.go

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -237,12 +237,13 @@ func newGraphServer(ctx context.Context, r *Router, routerConfig *nodev1.RouterC
237237
)
238238
})))
239239

240-
// Request traffic shaping related middlewares
241-
httpRouter.Use(rmiddleware.RequestSize(int64(s.routerTrafficConfig.MaxRequestBodyBytes)))
242240
if s.routerTrafficConfig.DecompressionEnabled {
243241
httpRouter.Use(rmiddleware.HandleCompression(s.logger))
244242
}
245243

244+
// Request traffic shaping related middlewares, happens after decompression to prevent unbounded decompression attacks
245+
httpRouter.Use(rmiddleware.RequestSize(int64(s.routerTrafficConfig.MaxRequestBodyBytes)))
246+
246247
httpRouter.Use(middleware.RequestID)
247248
httpRouter.Use(middleware.RealIP)
248249
if s.corsOptions.Enabled {

router/internal/middleware/compression.go

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ package middleware
22

33
import (
44
"compress/gzip"
5+
56
"net/http"
67
"strings"
78

@@ -32,14 +33,23 @@ func HandleCompression(logger *zap.Logger) func(http.Handler) http.Handler {
3233
return
3334
}
3435

36+
originalBody := r.Body
37+
3538
defer func() {
3639
if err := gzr.Close(); err != nil {
3740
logger.Error("failed to close gzip reader", zap.Error(err))
3841
}
42+
43+
if err := originalBody.Close(); err != nil {
44+
logger.Error("failed to close original body", zap.Error(err))
45+
}
3946
}()
4047

4148
r.Body = gzr
4249

50+
// Content-Length is no longer valid after decompression
51+
r.Header.Del("Content-Length")
52+
r.ContentLength = -1
4353
case "":
4454
default:
4555
http.Error(w, "unsupported media type", http.StatusUnsupportedMediaType)

router/internal/middleware/compression_test.go

Lines changed: 11 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,15 @@
11
package middleware
22

33
import (
4-
"compress/gzip"
4+
"bytes"
55
"io"
66
"net/http"
77
"net/http/httptest"
88
"strings"
99
"testing"
1010

11+
"compress/gzip"
12+
1113
"github.com/stretchr/testify/require"
1214
"go.uber.org/zap"
1315
)
@@ -45,8 +47,7 @@ func TestHandleCompression(t *testing.T) {
4547
w.WriteHeader(http.StatusOK)
4648
})
4749

48-
req, err := http.NewRequest(tc.method, "/", strings.NewReader("test"))
49-
require.NoError(t, err)
50+
req := httptest.NewRequest(tc.method, "/", strings.NewReader("test"))
5051

5152
HandleCompression(zap.NewNop())(next).ServeHTTP(recorder, req)
5253

@@ -64,8 +65,7 @@ func TestHandleCompression(t *testing.T) {
6465
w.WriteHeader(http.StatusOK)
6566
})
6667

67-
req, err := http.NewRequest(http.MethodPost, "/", strings.NewReader("test"))
68-
require.NoError(t, err)
68+
req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader("test"))
6969
req.Header.Set("Content-Encoding", "gzip, deflate")
7070

7171
HandleCompression(zap.NewNop())(next).ServeHTTP(recorder, req)
@@ -92,16 +92,14 @@ func TestHandleCompression(t *testing.T) {
9292

9393
// create gzip compressed request
9494

95-
var sb strings.Builder
96-
w := gzip.NewWriter(&sb)
95+
var buf bytes.Buffer
96+
w := gzip.NewWriter(&buf)
9797

9898
_, err := w.Write([]byte("test"))
9999
require.NoError(t, err)
100-
101100
require.NoError(t, w.Close())
102101

103-
req, err := http.NewRequest(http.MethodPost, "/", strings.NewReader(sb.String()))
104-
require.NoError(t, err)
102+
req := httptest.NewRequest(http.MethodPost, "/", &buf)
105103
req.Header.Set("Content-Encoding", "gzip")
106104

107105
HandleCompression(zap.NewNop())(next).ServeHTTP(recorder, req)
@@ -118,8 +116,7 @@ func TestHandleCompression(t *testing.T) {
118116
w.WriteHeader(http.StatusOK)
119117
})
120118

121-
req, err := http.NewRequest(http.MethodPost, "/", strings.NewReader("test"))
122-
require.NoError(t, err)
119+
req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader("test"))
123120
req.Header.Set("Content-Encoding", "gzip")
124121

125122
HandleCompression(zap.NewNop())(next).ServeHTTP(recorder, req)
@@ -136,8 +133,7 @@ func TestHandleCompression(t *testing.T) {
136133
w.WriteHeader(http.StatusOK)
137134
})
138135

139-
req, err := http.NewRequest(http.MethodPost, "/", strings.NewReader("test"))
140-
require.NoError(t, err)
136+
req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader("test"))
141137
req.Header.Set("Content-Encoding", "deflate")
142138

143139
HandleCompression(zap.NewNop())(next).ServeHTTP(recorder, req)
@@ -155,8 +151,7 @@ func TestHandleCompression(t *testing.T) {
155151
w.WriteHeader(http.StatusOK)
156152
})
157153

158-
req, err := http.NewRequest(http.MethodPost, "/", strings.NewReader("test"))
159-
require.NoError(t, err)
154+
req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader("test"))
160155

161156
HandleCompression(zap.NewNop())(next).ServeHTTP(recorder, req)
162157

Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,77 @@
1+
package middleware
2+
3+
import (
4+
"bytes"
5+
"compress/gzip"
6+
"io"
7+
"net/http"
8+
"net/http/httptest"
9+
"strings"
10+
"testing"
11+
12+
"github.com/go-chi/chi/v5"
13+
"github.com/stretchr/testify/require"
14+
"go.uber.org/zap"
15+
)
16+
17+
func TestRequestSizeAndCompression(t *testing.T) {
18+
// Define maximum allowed request size
19+
const maxRequestSize = 1024 // Example: 1 KB
20+
21+
// Chain the compression and request_size middlewares
22+
r := chi.NewMux()
23+
r.Use(HandleCompression(zap.NewNop()))
24+
r.Use(RequestSize(maxRequestSize))
25+
26+
t.Run("request size limiter should not allow requests exceeding allowed maximum request size", func(t *testing.T) {
27+
// Recorder to capture the response
28+
recorder := httptest.NewRecorder()
29+
30+
// Write a large payload that exceeds the maxRequestSize after decompression
31+
largePayload := strings.Repeat("A", maxRequestSize*10) // 10x the max size
32+
33+
// Create the request with the gzip bomb payload
34+
req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(largePayload))
35+
36+
r.HandleFunc("/", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
37+
_, err := io.Copy(io.Discard, r.Body)
38+
require.ErrorContains(t, err, "http: request body too large")
39+
w.WriteHeader(http.StatusRequestEntityTooLarge)
40+
}))
41+
42+
r.ServeHTTP(recorder, req)
43+
44+
// Assert that the response status is 413 Payload Too Large
45+
require.Equal(t, http.StatusRequestEntityTooLarge, recorder.Code)
46+
})
47+
48+
t.Run("request size limiter should reject gzip bomb exceeding allowed maximum request size", func(t *testing.T) {
49+
// Recorder to capture the response
50+
recorder := httptest.NewRecorder()
51+
52+
// Create a gzip bomb payload
53+
var buf bytes.Buffer
54+
w := gzip.NewWriter(&buf)
55+
56+
// Write a large payload that exceeds the maxRequestSize after decompression
57+
largePayload := strings.Repeat("A", maxRequestSize*10) // 10x the max size
58+
_, err := w.Write([]byte(largePayload))
59+
require.NoError(t, err)
60+
require.NoError(t, w.Close())
61+
62+
// Create the request with the gzip bomb payload
63+
req := httptest.NewRequest(http.MethodPost, "/", &buf)
64+
req.Header.Set("Content-Encoding", "gzip")
65+
66+
r.HandleFunc("/", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
67+
_, err := io.Copy(io.Discard, r.Body)
68+
require.ErrorContains(t, err, "http: request body too large")
69+
w.WriteHeader(http.StatusRequestEntityTooLarge)
70+
}))
71+
72+
r.ServeHTTP(recorder, req)
73+
74+
// Assert that the response status is 413 Payload Too Large
75+
require.Equal(t, http.StatusRequestEntityTooLarge, recorder.Code)
76+
})
77+
}

0 commit comments

Comments
 (0)