Skip to content

replace the custom HeaderMap type with http.Header from stdlib in the header middleware #59

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jun 25, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 10 additions & 10 deletions examples/headers_middleware/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -76,23 +76,23 @@ func buildRoutes(logHandler slog.Handler) ([]httpserver.Route, error) {
headersMw := headers.NewWithOperations(
// Request header operations (applied before handler)
headers.WithRemoveRequest("X-Forwarded-For", "X-Real-IP"), // Remove proxy headers
headers.WithSetRequest(headers.HeaderMap{
"X-Request-Source": "go-supervisor-example", // Set request source
headers.WithSetRequest(http.Header{
"X-Request-Source": []string{"go-supervisor-example"}, // Set request source
}),
headers.WithAddRequest(headers.HeaderMap{
"X-Internal-Request": "true", // Mark as internal
headers.WithAddRequest(http.Header{
"X-Internal-Request": []string{"true"}, // Mark as internal
}),
headers.WithAddRequestHeader("X-Processing-Time", time.Now().Format(time.RFC3339)),

// Response header operations (applied after handler)
headers.WithRemove("Server", "X-Powered-By"), // Remove server identification
headers.WithSet(headers.HeaderMap{
"X-Frame-Options": "DENY", // Security header
"X-API-Version": "v1.0", // API version
"Content-Type": "application/json", // JSON responses
headers.WithSet(http.Header{
"X-Frame-Options": []string{"DENY"}, // Security header
"X-API-Version": []string{"v1.0"}, // API version
"Content-Type": []string{"application/json"}, // JSON responses
}),
headers.WithAdd(headers.HeaderMap{
"X-Custom-Header": "go-supervisor-headers", // Custom header
headers.WithAdd(http.Header{
"X-Custom-Header": []string{"go-supervisor-headers"}, // Custom header
}),
headers.WithAddHeader("X-Response-Time", time.Now().Format(time.RFC3339)),
)
Expand Down
2 changes: 1 addition & 1 deletion runnables/httpserver/middleware/compliance_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -210,7 +210,7 @@ func createRequestProcessor(
func TestBuiltinMiddlewareCompliance(t *testing.T) {
// Test headers middleware
t.Run("headers middleware", func(t *testing.T) {
headersMiddleware := headers.New(headers.HeaderMap{"Content-Type": "application/json"})
headersMiddleware := headers.New(http.Header{"Content-Type": []string{"application/json"}})
test := NewMiddlewareComplianceTest(t, "headers", headersMiddleware)
test.RunAllTests()
})
Expand Down
41 changes: 21 additions & 20 deletions runnables/httpserver/middleware/headers/headers.go
Original file line number Diff line number Diff line change
@@ -1,24 +1,25 @@
package headers

import (
"net/http"

"github.com/robbyt/go-supervisor/runnables/httpserver"
)

// HeaderMap represents a collection of HTTP headers
type HeaderMap map[string]string

// New creates a middleware that sets HTTP headers on responses.
// Headers are set before the request is processed, allowing other middleware
// and handlers to override them if needed.
//
// Note: The Go standard library's http package will validate headers when
// writing them to prevent protocol violations. This middleware does not
// perform additional validation beyond what the standard library provides.
func New(headers HeaderMap) httpserver.HandlerFunc {
func New(headers http.Header) httpserver.HandlerFunc {
return func(rp *httpserver.RequestProcessor) {
// Set headers before processing
for key, value := range headers {
rp.Writer().Header().Set(key, value)
for key, values := range headers {
for _, value := range values {
rp.Writer().Header().Add(key, value)
}
}

// Continue processing
Expand All @@ -29,9 +30,9 @@ func New(headers HeaderMap) httpserver.HandlerFunc {
// JSON creates a middleware that sets JSON-specific headers.
// Sets Content-Type to application/json and Cache-Control to no-cache.
func JSON() httpserver.HandlerFunc {
return New(HeaderMap{
"Content-Type": "application/json",
"Cache-Control": "no-cache",
return New(http.Header{
"Content-Type": []string{"application/json"},
"Cache-Control": []string{"no-cache"},
})
}

Expand All @@ -56,32 +57,32 @@ func JSON() httpserver.HandlerFunc {
// // Development setup with all methods
// CORS("http://localhost:3000", "GET,POST,PUT,PATCH,DELETE,OPTIONS", "*")
func CORS(allowOrigin, allowMethods, allowHeaders string) httpserver.HandlerFunc {
corsHeaders := HeaderMap{
"Access-Control-Allow-Origin": allowOrigin,
"Access-Control-Allow-Methods": allowMethods,
"Access-Control-Allow-Headers": allowHeaders,
corsHeaders := http.Header{
"Access-Control-Allow-Origin": []string{allowOrigin},
"Access-Control-Allow-Methods": []string{allowMethods},
"Access-Control-Allow-Headers": []string{allowHeaders},
}

// Add credentials header if origin is not wildcard
if allowOrigin != "*" {
corsHeaders["Access-Control-Allow-Credentials"] = "true"
corsHeaders["Access-Control-Allow-Credentials"] = []string{"true"}
}

return NewWithOperations(WithSet(corsHeaders))
}

// Security creates a middleware that sets common security headers.
func Security() httpserver.HandlerFunc {
return New(HeaderMap{
"X-Content-Type-Options": "nosniff",
"X-Frame-Options": "DENY",
"X-XSS-Protection": "1; mode=block",
"Referrer-Policy": "strict-origin-when-cross-origin",
return New(http.Header{
"X-Content-Type-Options": []string{"nosniff"},
"X-Frame-Options": []string{"DENY"},
"X-XSS-Protection": []string{"1; mode=block"},
"Referrer-Policy": []string{"strict-origin-when-cross-origin"},
})
}

// Add creates a middleware that adds a single header.
// This is useful for simple header additions.
func Add(key, value string) httpserver.HandlerFunc {
return New(HeaderMap{key: value})
return New(http.Header{key: []string{value}})
}
68 changes: 60 additions & 8 deletions runnables/httpserver/middleware/headers/headers_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,8 @@ func TestNew(t *testing.T) {
t.Run("sets single header", func(t *testing.T) {
t.Parallel()

headers := HeaderMap{
"X-Test-Header": "test-value",
headers := http.Header{
"X-Test-Header": []string{"test-value"},
}

middleware := New(headers)
Expand Down Expand Up @@ -47,10 +47,10 @@ func TestNew(t *testing.T) {
t.Run("sets multiple headers", func(t *testing.T) {
t.Parallel()

headers := HeaderMap{
"X-Header-One": "value-one",
"X-Header-Two": "value-two",
"Content-Type": "application/json",
headers := http.Header{
"X-Header-One": []string{"value-one"},
"X-Header-Two": []string{"value-two"},
"Content-Type": []string{"application/json"},
}

middleware := New(headers)
Expand Down Expand Up @@ -82,7 +82,7 @@ func TestNew(t *testing.T) {
t.Run("handles empty headers map", func(t *testing.T) {
t.Parallel()

middleware := New(HeaderMap{})
middleware := New(http.Header{})

req := httptest.NewRequest("GET", "/test", nil)
rec := httptest.NewRecorder()
Expand All @@ -100,7 +100,7 @@ func TestNew(t *testing.T) {
t.Run("allows headers to be overridden by subsequent middleware", func(t *testing.T) {
t.Parallel()

headerMiddleware := New(HeaderMap{"X-Test": "original"})
headerMiddleware := New(http.Header{"X-Test": []string{"original"}})
overrideMiddleware := func(rp *httpserver.RequestProcessor) {
rp.Writer().Header().Set("X-Test", "overridden")
rp.Next()
Expand All @@ -118,6 +118,58 @@ func TestNew(t *testing.T) {

assert.Equal(t, "overridden", rec.Header().Get("X-Test"), "header should be overridden")
})

t.Run("multiple Set-Cookie headers remain separate", func(t *testing.T) {
headers := http.Header{
"Set-Cookie": []string{
"session=abc123; Path=/; HttpOnly",
"theme=dark; Path=/; Max-Age=86400",
"lang=en; Path=/; SameSite=Strict",
},
}

middleware := New(headers)

req := httptest.NewRequest("GET", "/test", nil)
rec := httptest.NewRecorder()

route, err := httpserver.NewRouteFromHandlerFunc("test", "/test",
func(w http.ResponseWriter, r *http.Request) {}, middleware)
assert.NoError(t, err, "route creation should not fail")

route.ServeHTTP(rec, req)

cookies := rec.Header().Values("Set-Cookie")
assert.Len(t, cookies, 3, "should have three separate Set-Cookie headers")
assert.Contains(t, cookies, "session=abc123; Path=/; HttpOnly")
assert.Contains(t, cookies, "theme=dark; Path=/; Max-Age=86400")
assert.Contains(t, cookies, "lang=en; Path=/; SameSite=Strict")
})

t.Run("other headers can be comma-combined", func(t *testing.T) {
headers := http.Header{
"Accept": []string{"text/html", "application/json", "application/xml"},
"Accept-Encoding": []string{"gzip", "deflate", "br"},
}

middleware := New(headers)

req := httptest.NewRequest("GET", "/test", nil)
rec := httptest.NewRecorder()

route, err := httpserver.NewRouteFromHandlerFunc("test", "/test",
func(w http.ResponseWriter, r *http.Request) {}, middleware)
assert.NoError(t, err, "route creation should not fail")

route.ServeHTTP(rec, req)

// These headers can be comma-combined
acceptValues := rec.Header().Values("Accept")
assert.Len(t, acceptValues, 3, "should have three Accept values")

encodingValues := rec.Header().Values("Accept-Encoding")
assert.Len(t, encodingValues, 3, "should have three Accept-Encoding values")
})
}

func TestJSON(t *testing.T) {
Expand Down
38 changes: 22 additions & 16 deletions runnables/httpserver/middleware/headers/options.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,13 +19,13 @@ type headerOperations struct {
}

// WithSet creates an operation to set (replace) headers
func WithSet(headers HeaderMap) HeaderOperation {
func WithSet(headers http.Header) HeaderOperation {
return func(ops *headerOperations) {
if ops.setHeaders == nil {
ops.setHeaders = make(http.Header)
}
for key, value := range headers {
ops.setHeaders.Set(key, value)
for key, values := range headers {
ops.setHeaders[key] = values
}
}
}
Expand All @@ -41,13 +41,15 @@ func WithSetHeader(key, value string) HeaderOperation {
}

// WithAdd creates an operation to add (append) headers
func WithAdd(headers HeaderMap) HeaderOperation {
func WithAdd(headers http.Header) HeaderOperation {
return func(ops *headerOperations) {
if ops.addHeaders == nil {
ops.addHeaders = make(http.Header)
}
for key, value := range headers {
ops.addHeaders.Add(key, value)
for key, values := range headers {
for _, value := range values {
ops.addHeaders.Add(key, value)
}
}
}
}
Expand All @@ -70,13 +72,13 @@ func WithRemove(headerNames ...string) HeaderOperation {
}

// WithSetRequest creates an operation to set (replace) request headers
func WithSetRequest(headers HeaderMap) HeaderOperation {
func WithSetRequest(headers http.Header) HeaderOperation {
return func(ops *headerOperations) {
if ops.setRequestHeaders == nil {
ops.setRequestHeaders = make(http.Header)
}
for key, value := range headers {
ops.setRequestHeaders.Set(key, value)
for key, values := range headers {
ops.setRequestHeaders[key] = values
}
}
}
Expand All @@ -92,13 +94,15 @@ func WithSetRequestHeader(key, value string) HeaderOperation {
}

// WithAddRequest creates an operation to add (append) request headers
func WithAddRequest(headers HeaderMap) HeaderOperation {
func WithAddRequest(headers http.Header) HeaderOperation {
return func(ops *headerOperations) {
if ops.addRequestHeaders == nil {
ops.addRequestHeaders = make(http.Header)
}
for key, value := range headers {
ops.addRequestHeaders.Add(key, value)
for key, values := range headers {
for _, value := range values {
ops.addRequestHeaders.Add(key, value)
}
}
}
}
Expand Down Expand Up @@ -140,8 +144,9 @@ func NewWithOperations(operations ...HeaderOperation) httpserver.HandlerFunc {

// 2. Set request headers (replace)
for key, values := range ops.setRequestHeaders {
if len(values) > 0 {
request.Header.Set(key, values[0])
request.Header.Del(key)
for _, value := range values {
request.Header.Add(key, value)
}
}

Expand All @@ -160,8 +165,9 @@ func NewWithOperations(operations ...HeaderOperation) httpserver.HandlerFunc {

// 2. Set response headers (replace)
for key, values := range ops.setHeaders {
if len(values) > 0 {
writer.Header().Set(key, values[0])
writer.Header().Del(key)
for _, value := range values {
writer.Header().Add(key, value)
}
}

Expand Down
Loading