Skip to content

make it easier for external middleware implementations #49

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jun 1, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion runnables/httpserver/middleware/logger.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ func Logger(logger *slog.Logger) Middleware {
start := time.Now()

// Create a response writer wrapper to capture status code
rw := &responseWriter{
rw := &ResponseWriter{
ResponseWriter: w,
statusCode: http.StatusOK, // Default status code
}
Expand Down
2 changes: 1 addition & 1 deletion runnables/httpserver/middleware/metrics.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ func MetricCollector() Middleware {
start := time.Now()

// Create a response writer wrapper to capture status code
rw := &responseWriter{
rw := &ResponseWriter{
ResponseWriter: w,
statusCode: http.StatusOK, // Default status code
}
Expand Down
2 changes: 1 addition & 1 deletion runnables/httpserver/middleware/metrics_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ func TestMetricCollector(t *testing.T) {
t.Run("captures status code in responseWriter", func(t *testing.T) {
// Setup direct responseWriter (without middleware)
rec := httptest.NewRecorder()
rw := &responseWriter{
rw := &ResponseWriter{
ResponseWriter: rec,
statusCode: http.StatusOK, // Default
}
Expand Down
41 changes: 28 additions & 13 deletions runnables/httpserver/middleware/middleware.go
Original file line number Diff line number Diff line change
@@ -1,31 +1,46 @@
// Package middleware provides HTTP middleware utilities for wrapping http.HandlerFunc with additional functionality such as logging, metrics, and response inspection.
package middleware

import (
"net/http"
)

// Middleware is a function that takes an http.HandlerFunc and returns a new http.HandlerFunc
// which may execute code before and/or after calling the original handler.
// which may execute code before and/or after calling the original handler. It is commonly used
// for cross-cutting concerns such as logging, authentication, and metrics.
type Middleware func(http.HandlerFunc) http.HandlerFunc

// responseWriter is a wrapper for http.ResponseWriter that captures the status code.
type responseWriter struct {
// ResponseWriter is a wrapper for http.ResponseWriter that captures the status code and the number of bytes written.
// It is useful in middleware for logging, metrics, and conditional logic based on the response.
type ResponseWriter struct {
http.ResponseWriter
statusCode int
written bool
statusCode int
written bool
bytesWritten int
}

// WriteHeader captures the status code and calls the underlying WriteHeader.
func (rw *responseWriter) WriteHeader(statusCode int) {
rw.statusCode = statusCode
func (rw *ResponseWriter) WriteHeader(statusCode int) {
rw.written = true
rw.statusCode = statusCode
rw.ResponseWriter.WriteHeader(statusCode)
}

// Write captures that a response has been written and calls the underlying Write.
func (rw *responseWriter) Write(b []byte) (int, error) {
if !rw.written {
rw.written = true
}
return rw.ResponseWriter.Write(b)
// Write captures that a response has been written, counts the bytes, and calls the underlying Write.
func (rw *ResponseWriter) Write(b []byte) (int, error) {
rw.written = true
n, err := rw.ResponseWriter.Write(b)
rw.bytesWritten += n
return n, err
}

// Status returns the HTTP status code that was written to the response.
// If no status code was explicitly set, it returns 0.
func (rw *ResponseWriter) Status() int {
return rw.statusCode
}

// BytesWritten returns the total number of bytes written to the response body.
func (rw *ResponseWriter) BytesWritten() int {
return rw.bytesWritten
}
48 changes: 44 additions & 4 deletions runnables/httpserver/middleware/middleware_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ import (
func TestResponseWriter(t *testing.T) {
t.Run("WriteHeader sets status code", func(t *testing.T) {
rec := httptest.NewRecorder()
rw := &responseWriter{
rw := &ResponseWriter{
ResponseWriter: rec,
statusCode: http.StatusOK, // Default
}
Expand All @@ -30,7 +30,7 @@ func TestResponseWriter(t *testing.T) {

t.Run("Write sets written flag", func(t *testing.T) {
rec := httptest.NewRecorder()
rw := &responseWriter{
rw := &ResponseWriter{
ResponseWriter: rec,
statusCode: http.StatusOK,
written: false,
Expand All @@ -51,7 +51,7 @@ func TestResponseWriter(t *testing.T) {

t.Run("Write doesn't change written flag if already set", func(t *testing.T) {
rec := httptest.NewRecorder()
rw := &responseWriter{
rw := &ResponseWriter{
ResponseWriter: rec,
statusCode: http.StatusOK,
written: true, // Already set
Expand All @@ -69,7 +69,7 @@ func TestResponseWriter(t *testing.T) {
// Create a test middleware that uses the responseWriter
testMiddleware := func(next http.HandlerFunc) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
rw := &responseWriter{
rw := &ResponseWriter{
ResponseWriter: w,
statusCode: http.StatusOK,
}
Expand Down Expand Up @@ -99,4 +99,44 @@ func TestResponseWriter(t *testing.T) {
assert.Equal(t, http.StatusCreated, rec.Code)
assert.Equal(t, "Created", rec.Body.String())
})

t.Run("Status() returns correct status", func(t *testing.T) {
rec := httptest.NewRecorder()
rw := &ResponseWriter{
ResponseWriter: rec,
statusCode: http.StatusOK,
}

// No WriteHeader called yet
assert.Equal(t, http.StatusOK, rw.Status())

rw.WriteHeader(http.StatusTeapot)
assert.Equal(t, http.StatusTeapot, rw.Status())
})

t.Run("BytesWritten() returns correct count", func(t *testing.T) {
rec := httptest.NewRecorder()
rw := &ResponseWriter{
ResponseWriter: rec,
statusCode: http.StatusOK,
}

// No bytes written yet
assert.Equal(t, 0, rw.BytesWritten())

_, err := rw.Write([]byte("foo"))
assert.NoError(t, err)
_, err = rw.Write([]byte("barbaz"))
assert.NoError(t, err)
assert.Equal(t, 9, rw.BytesWritten())
})

t.Run("Status() and BytesWritten() default values", func(t *testing.T) {
rec := httptest.NewRecorder()
rw := &ResponseWriter{
ResponseWriter: rec,
}
assert.Equal(t, 0, rw.Status())
assert.Equal(t, 0, rw.BytesWritten())
})
}