Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 28 additions & 1 deletion pkg/http/handler/timeout.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ import (
"io"
"net"
"net/http"
"slices"
"sync"
"time"

Expand Down Expand Up @@ -169,6 +170,9 @@ type timeoutWriter struct {
mu sync.Mutex
timedOut bool
lastWriteTime time.Time
// headers is a snapshot of headers taken when timeout occurs
// to prevent concurrent map access
headers http.Header
}

var (
Expand Down Expand Up @@ -201,7 +205,23 @@ func (tw *timeoutWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
return websocket.HijackIfPossible(tw.w)
}

func (tw *timeoutWriter) Header() http.Header { return tw.w.Header() }
func (tw *timeoutWriter) Header() http.Header {
tw.mu.Lock()
timedOut := tw.timedOut
headers := tw.headers
tw.mu.Unlock()

if timedOut {
// Return the snapshot of headers taken at timeout to prevent
// concurrent modification of the header map
if headers == nil {
// If no headers were captured, return an empty map
return make(http.Header)
}
return headers
}
return tw.w.Header()
}

func (tw *timeoutWriter) Write(p []byte) (int, error) {
tw.mu.Lock()
Expand Down Expand Up @@ -279,6 +299,13 @@ func (tw *timeoutWriter) tryIdleTimeoutAndWriteError(curTime time.Time, idleTime
}

func (tw *timeoutWriter) timeoutAndWriteError(msg string) {
// Capture a snapshot of headers before marking as timed out
// to prevent concurrent access to the underlying header map
tw.headers = make(http.Header)
for k, v := range tw.w.Header() {
tw.headers[k] = slices.Clone(v)
}

tw.w.WriteHeader(http.StatusGatewayTimeout)
Comment on lines +302 to 309
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not 100% sure if I understood the timeoutHandler/Writer stuff correctly but wouldn't we still see a concurrent access to the map here?

io.WriteString(tw.w, msg)

Expand Down
94 changes: 94 additions & 0 deletions pkg/http/handler/timeout_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ import (
"net/http"
"net/http/httptest"
"sync"
"sync/atomic"
"testing"
"time"

Expand Down Expand Up @@ -626,6 +627,99 @@ func BenchmarkTimeoutHandler(b *testing.B) {
})
}

func TestTimeoutHandlerConcurrentHeaderAccess(t *testing.T) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Are we sure that this test would actually catch verify the problem?
I tried running it locally without the fix 1000 times but the test didn't fail.

// This test verifies the fix for the race condition when requests time out.
// It simulates the scenario where the timeout handler completes while the
// inner handler is still trying to modify headers. The key is that this
// should not panic with a concurrent map access error.

var completedCount atomic.Int32
var panicCount int32
innerHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Simulate work that takes around the same time as timeout
time.Sleep(55 * time.Millisecond)

// After potential context cancellation, try to access headers
// This simulates what the error handler does
if r.Context().Err() != nil {
// Try to modify headers - this should not cause a panic
// even if timeout has occurred
w.Header().Set("X-Test-Header", "value")
http.Error(w, "context canceled", http.StatusBadGateway)
} else {
// If no timeout, write normally
w.WriteHeader(http.StatusOK)
}
completedCount.Add(1)
})

timeoutHandler := NewTimeoutHandler(
innerHandler,
"timeout",
func(r *http.Request) (time.Duration, time.Duration, time.Duration) {
return 50 * time.Millisecond, 0, 0
},
zaptest.NewLogger(t).Sugar(),
)

// Run multiple concurrent requests to increase chances of hitting the race
var wg sync.WaitGroup
var timeoutResponses atomic.Int32
var normalResponses atomic.Int32
for range 10 {
wg.Add(1)
go func() {
defer wg.Done()
defer func() {
if r := recover(); r != nil {
// Should not panic with concurrent map access
atomic.AddInt32(&panicCount, 1)
t.Errorf("Unexpected panic: %v", r)
}
}()

req, err := http.NewRequest(http.MethodGet, "/", nil)
if err != nil {
t.Error(err)
return
}

rec := httptest.NewRecorder()

// This should not panic with concurrent map access
timeoutHandler.ServeHTTP(rec, req)

// We may get either a timeout or a normal response depending on timing
// The key is that we don't panic
switch rec.Code {
case http.StatusGatewayTimeout:
timeoutResponses.Add(1)
case http.StatusOK:
normalResponses.Add(1)
default:
t.Errorf("Unexpected status code: %d", rec.Code)
}
}()
}

wg.Wait()

// Give a bit more time for any lingering goroutines to complete
time.Sleep(100 * time.Millisecond)

// Check that no panics occurred
if panicCount > 0 {
t.Errorf("Got %d panics, expected 0", panicCount)
}

// At least some requests should have timed out
if timeoutResponses.Load() == 0 {
t.Error("Expected at least some timeout responses")
}

t.Logf("Got %d timeout responses and %d normal responses", timeoutResponses.Load(), normalResponses.Load())
}

func StaticTimeoutFunc(timeout time.Duration, requestStart time.Duration, idle time.Duration) TimeoutFunc {
return func(req *http.Request) (time.Duration, time.Duration, time.Duration) {
return timeout, requestStart, idle
Expand Down
Loading