Skip to content

Commit a7ac232

Browse files
committed
improve middleware test format, add more tests
1 parent c2b050f commit a7ac232

File tree

1 file changed

+93
-20
lines changed

1 file changed

+93
-20
lines changed

runnables/httpserver/middleware/middleware_test.go

Lines changed: 93 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package middleware
22

33
import (
4+
"errors"
45
"net/http"
56
"net/http/httptest"
67
"testing"
@@ -20,12 +21,9 @@ func TestResponseWriter(t *testing.T) {
2021
// Write a custom status code
2122
rw.WriteHeader(http.StatusNotFound)
2223

23-
// Check that the status code was set in our wrapper
24-
assert.Equal(t, http.StatusNotFound, rw.statusCode)
25-
// Check that the status code was passed to the underlying ResponseWriter
26-
assert.Equal(t, http.StatusNotFound, rec.Code)
27-
// Check that written flag was set
28-
assert.True(t, rw.written)
24+
assert.Equal(t, http.StatusNotFound, rw.statusCode, "status code should be set in wrapper")
25+
assert.Equal(t, http.StatusNotFound, rec.Code, "status code should be passed to underlying ResponseWriter")
26+
assert.True(t, rw.written, "written flag should be set")
2927
})
3028

3129
t.Run("Write sets written flag", func(t *testing.T) {
@@ -41,12 +39,9 @@ func TestResponseWriter(t *testing.T) {
4139
require.NoError(t, err)
4240
require.Equal(t, 4, n)
4341

44-
// Check written flag is set
45-
assert.True(t, rw.written)
46-
// Status code should remain default
47-
assert.Equal(t, http.StatusOK, rw.statusCode)
48-
// Content should be written to underlying ResponseWriter
49-
assert.Equal(t, "test", rec.Body.String())
42+
assert.True(t, rw.written, "written flag should be set")
43+
assert.Equal(t, http.StatusOK, rw.statusCode, "status code should remain default")
44+
assert.Equal(t, "test", rec.Body.String(), "content should be written to underlying ResponseWriter")
5045
})
5146

5247
t.Run("Write doesn't change written flag if already set", func(t *testing.T) {
@@ -61,8 +56,7 @@ func TestResponseWriter(t *testing.T) {
6156
_, err := rw.Write([]byte("test"))
6257
require.NoError(t, err)
6358

64-
// Check written flag is still true
65-
assert.True(t, rw.written)
59+
assert.True(t, rw.written, "written flag should still be true")
6660
})
6761

6862
t.Run("Integration with middleware", func(t *testing.T) {
@@ -74,8 +68,7 @@ func TestResponseWriter(t *testing.T) {
7468
statusCode: http.StatusOK,
7569
}
7670
next(rw, r)
77-
// We can now access the statusCode that was set by the handler
78-
assert.Equal(t, http.StatusCreated, rw.statusCode)
71+
assert.Equal(t, http.StatusCreated, rw.statusCode, "middleware should capture status code set by handler")
7972
}
8073
}
8174

@@ -107,8 +100,7 @@ func TestResponseWriter(t *testing.T) {
107100
statusCode: http.StatusOK,
108101
}
109102

110-
// No WriteHeader called yet
111-
assert.Equal(t, http.StatusOK, rw.Status())
103+
assert.Equal(t, http.StatusOK, rw.Status(), "should return initial status code when no WriteHeader called")
112104

113105
rw.WriteHeader(http.StatusTeapot)
114106
assert.Equal(t, http.StatusTeapot, rw.Status())
@@ -121,8 +113,7 @@ func TestResponseWriter(t *testing.T) {
121113
statusCode: http.StatusOK,
122114
}
123115

124-
// No bytes written yet
125-
assert.Equal(t, 0, rw.BytesWritten())
116+
assert.Equal(t, 0, rw.BytesWritten(), "should return 0 bytes when nothing written")
126117

127118
_, err := rw.Write([]byte("foo"))
128119
assert.NoError(t, err)
@@ -139,4 +130,86 @@ func TestResponseWriter(t *testing.T) {
139130
assert.Equal(t, 0, rw.Status())
140131
assert.Equal(t, 0, rw.BytesWritten())
141132
})
133+
134+
t.Run("Write() method error handling", func(t *testing.T) {
135+
// Create a mock ResponseWriter that returns an error
136+
mockWriter := &errorResponseWriter{err: errors.New("write failed")}
137+
rw := &ResponseWriter{
138+
ResponseWriter: mockWriter,
139+
}
140+
141+
n, err := rw.Write([]byte("test"))
142+
assert.Error(t, err)
143+
assert.Equal(t, "write failed", err.Error())
144+
assert.Equal(t, 0, n)
145+
assert.Equal(t, 0, rw.BytesWritten(), "no bytes should be counted on error")
146+
assert.True(t, rw.written, "written flag should still be set")
147+
})
148+
149+
t.Run("Status() without explicit WriteHeader()", func(t *testing.T) {
150+
rec := httptest.NewRecorder()
151+
rw := &ResponseWriter{
152+
ResponseWriter: rec,
153+
}
154+
155+
// Call Write() without WriteHeader() - should keep default status (0)
156+
_, err := rw.Write([]byte("test"))
157+
assert.NoError(t, err)
158+
assert.Equal(t, 0, rw.Status(), "should remain 0 (default) when WriteHeader not called")
159+
assert.Equal(t, 4, rw.BytesWritten(), "should count bytes written")
160+
assert.True(t, rw.written, "should set written flag")
161+
})
162+
163+
t.Run("Byte counting with partial writes", func(t *testing.T) {
164+
// Create a mock that only writes part of the data
165+
mockWriter := &partialResponseWriter{written: 2}
166+
rw := &ResponseWriter{
167+
ResponseWriter: mockWriter,
168+
}
169+
170+
n, err := rw.Write([]byte("test"))
171+
assert.NoError(t, err)
172+
assert.Equal(t, 2, n, "only 2 bytes should be written by partial writer")
173+
assert.Equal(t, 2, rw.BytesWritten(), "should track actual bytes written")
174+
assert.True(t, rw.written)
175+
176+
// Write again
177+
n, err = rw.Write([]byte("more"))
178+
assert.NoError(t, err)
179+
assert.Equal(t, 2, n, "only 2 bytes should be written again by partial writer")
180+
assert.Equal(t, 4, rw.BytesWritten(), "should accumulate bytes: 2 + 2 = 4")
181+
})
182+
}
183+
184+
// errorResponseWriter is a mock that always returns an error on Write
185+
type errorResponseWriter struct {
186+
err error
187+
}
188+
189+
func (e *errorResponseWriter) Header() http.Header {
190+
return make(http.Header)
142191
}
192+
193+
func (e *errorResponseWriter) Write([]byte) (int, error) {
194+
return 0, e.err
195+
}
196+
197+
func (e *errorResponseWriter) WriteHeader(int) {}
198+
199+
// partialResponseWriter is a mock that only writes a fixed number of bytes
200+
type partialResponseWriter struct {
201+
written int
202+
}
203+
204+
func (p *partialResponseWriter) Header() http.Header {
205+
return make(http.Header)
206+
}
207+
208+
func (p *partialResponseWriter) Write(data []byte) (int, error) {
209+
if len(data) < p.written {
210+
return len(data), nil
211+
}
212+
return p.written, nil
213+
}
214+
215+
func (p *partialResponseWriter) WriteHeader(int) {}

0 commit comments

Comments
 (0)