Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
128 changes: 128 additions & 0 deletions internal/config/transient_http_error_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,128 @@
package config

import (
"net/http"
"testing"

"github.com/stretchr/testify/assert"
)

// TestIsTransientHTTPError verifies every status-code branch in isTransientHTTPError.
// The function returns true for HTTP 429 (TooManyRequests), 503 (ServiceUnavailable),
// and any 5xx status code, and false for all other codes.
func TestIsTransientHTTPError(t *testing.T) {
t.Parallel()

tests := []struct {
name string
statusCode int
want bool
}{
// True cases — rate limiting
{
name: "429 Too Many Requests is transient",
statusCode: http.StatusTooManyRequests,
want: true,
},
// True cases — service unavailable (also in 5xx range but named explicitly)
{
name: "503 Service Unavailable is transient",
statusCode: http.StatusServiceUnavailable,
want: true,
},
// True cases — full 5xx range
{
name: "500 Internal Server Error is transient",
statusCode: http.StatusInternalServerError,
want: true,
},
{
name: "501 Not Implemented is transient",
statusCode: http.StatusNotImplemented,
want: true,
},
{
name: "502 Bad Gateway is transient",
statusCode: http.StatusBadGateway,
want: true,
},
{
name: "504 Gateway Timeout is transient",
statusCode: http.StatusGatewayTimeout,
want: true,
},
{
name: "599 (max 5xx) is transient",
statusCode: 599,
want: true,
},
// False cases — successful responses
{
name: "200 OK is not transient",
statusCode: http.StatusOK,
want: false,
},
{
name: "201 Created is not transient",
statusCode: http.StatusCreated,
want: false,
},
{
name: "204 No Content is not transient",
statusCode: http.StatusNoContent,
want: false,
},
// False cases — redirects
{
name: "301 Moved Permanently is not transient",
statusCode: http.StatusMovedPermanently,
want: false,
},
// False cases — client errors (non-429)
{
name: "400 Bad Request is not transient",
statusCode: http.StatusBadRequest,
want: false,
},
{
name: "401 Unauthorized is not transient",
statusCode: http.StatusUnauthorized,
want: false,
},
{
name: "403 Forbidden is not transient",
statusCode: http.StatusForbidden,
want: false,
},
{
name: "404 Not Found is not transient",
statusCode: http.StatusNotFound,
want: false,
},
{
name: "422 Unprocessable Entity is not transient",
statusCode: http.StatusUnprocessableEntity,
want: false,
},
// Boundary: 499 is not transient (last 4xx)
{
name: "499 is not transient",
statusCode: 499,
want: false,
},
// Boundary: 600 is not transient (above 5xx)
{
name: "600 is not transient",
statusCode: 600,
want: false,
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
got := isTransientHTTPError(tt.statusCode)
assert.Equal(t, tt.want, got, "isTransientHTTPError(%d)", tt.statusCode)
})
}
}
213 changes: 213 additions & 0 deletions internal/server/peek_request_body_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,213 @@
package server

import (
"bytes"
"errors"
"io"
"net/http"
"net/http/httptest"
"testing"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

// errorOnReadReader is a test helper that returns an error on Read.
type errorOnReadReader struct {
readErr error
}

func (r *errorOnReadReader) Read(_ []byte) (int, error) { return 0, r.readErr }
func (r *errorOnReadReader) Close() error { return nil }

// errorOnCloseReader is a test helper that succeeds on Read but fails on Close.
type errorOnCloseReader struct {
data *bytes.Reader
closeErr error
}

func (r *errorOnCloseReader) Read(p []byte) (int, error) { return r.data.Read(p) }
func (r *errorOnCloseReader) Close() error { return r.closeErr }

// TestPeekRequestBody verifies all branches of peekRequestBody: non-POST methods,
// nil/NoBody bodies, read errors, close errors, empty body, and non-empty body with
// body-restoration behaviour.
func TestPeekRequestBody(t *testing.T) {
t.Parallel()

readErr := errors.New("simulated read error")
closeErr := errors.New("simulated close error")

tests := []struct {
name string
buildReq func() *http.Request
wantBytes []byte
wantErr error
checkBody bool // verify body is readable again after the call
wantBodyVal string
}{
{
name: "GET request returns nil without touching body",
buildReq: func() *http.Request {
return httptest.NewRequest(http.MethodGet, "/", bytes.NewBufferString("hello"))
},
wantBytes: nil,
wantErr: nil,
},
{
name: "PUT request returns nil without touching body",
buildReq: func() *http.Request {
return httptest.NewRequest(http.MethodPut, "/", bytes.NewBufferString("hello"))
},
wantBytes: nil,
wantErr: nil,
},
{
name: "DELETE request returns nil",
buildReq: func() *http.Request {
return httptest.NewRequest(http.MethodDelete, "/", nil)
},
wantBytes: nil,
wantErr: nil,
},
{
name: "POST with nil body returns nil",
buildReq: func() *http.Request {
req := httptest.NewRequest(http.MethodPost, "/", nil)
req.Body = nil
return req
},
wantBytes: nil,
wantErr: nil,
},
Comment on lines +73 to +82
Copy link

Copilot AI Apr 7, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The "POST with nil body" case doesn’t actually exercise the r.Body == nil branch: httptest.NewRequest(..., nil) sets req.Body to http.NoBody (not nil), so this is effectively duplicating the explicit http.NoBody test. To cover the intended branch, explicitly set req.Body = nil after constructing the request (or rename/remove this case).

Copilot uses AI. Check for mistakes.
{
name: "POST with http.NoBody returns nil",
buildReq: func() *http.Request {
req := httptest.NewRequest(http.MethodPost, "/", nil)
req.Body = http.NoBody
return req
},
wantBytes: nil,
wantErr: nil,
},
{
name: "POST with non-empty body returns bytes and restores body",
buildReq: func() *http.Request {
return httptest.NewRequest(http.MethodPost, "/", bytes.NewBufferString(`{"method":"tools/list"}`))
},
wantBytes: []byte(`{"method":"tools/list"}`),
wantErr: nil,
checkBody: true,
wantBodyVal: `{"method":"tools/list"}`,
},
{
name: "POST with binary body restores body for re-reading",
buildReq: func() *http.Request {
content := []byte{0x00, 0x01, 0x02, 0xFF}
return httptest.NewRequest(http.MethodPost, "/", bytes.NewReader(content))
},
wantBytes: []byte{0x00, 0x01, 0x02, 0xFF},
wantErr: nil,
checkBody: true,
},
{
name: "POST with empty body (reader at EOF) returns empty slice",
buildReq: func() *http.Request {
req := httptest.NewRequest(http.MethodPost, "/", bytes.NewBufferString(""))
// httptest.NewRequest wraps an empty buffer in a ReadCloser rather than
// using http.NoBody, so this exercises the len(b)==0 branch.
return req
},
wantBytes: []byte{},
wantErr: nil,
},
{
name: "POST with read error propagates the error",
buildReq: func() *http.Request {
req := httptest.NewRequest(http.MethodPost, "/", nil)
req.Body = &errorOnReadReader{readErr: readErr}
return req
},
wantBytes: nil,
wantErr: readErr,
},
{
name: "POST with close error propagates the error",
buildReq: func() *http.Request {
req := httptest.NewRequest(http.MethodPost, "/", nil)
req.Body = &errorOnCloseReader{
data: bytes.NewReader([]byte("some content")),
closeErr: closeErr,
}
return req
},
wantBytes: nil,
wantErr: closeErr,
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
req := tt.buildReq()

got, err := peekRequestBody(req)

if tt.wantErr != nil {
require.Error(t, err)
assert.ErrorIs(t, err, tt.wantErr)
return
}

require.NoError(t, err)
assert.Equal(t, tt.wantBytes, got)

if tt.checkBody {
// Verify that peekRequestBody restored the body so it can be read again.
require.NotNil(t, req.Body, "body should not be nil after peek")
assert.NotEqual(t, http.NoBody, req.Body, "body should be readable, not NoBody")

restored, readErr := io.ReadAll(req.Body)
require.NoError(t, readErr)

if tt.wantBodyVal != "" {
assert.Equal(t, tt.wantBodyVal, string(restored))
} else {
assert.Equal(t, tt.wantBytes, restored)
}
}
})
}
}

// TestPeekRequestBody_BodyRestoredForMultipleReads confirms the fundamental contract:
// after peekRequestBody returns, downstream handlers can still read the full body.
func TestPeekRequestBody_BodyRestoredForMultipleReads(t *testing.T) {
t.Parallel()

body := `{"jsonrpc":"2.0","method":"tools/call","id":1}`
req := httptest.NewRequest(http.MethodPost, "/mcp", bytes.NewBufferString(body))

// First peek
b1, err := peekRequestBody(req)
require.NoError(t, err)
assert.Equal(t, body, string(b1))

// Body must still be fully readable after the peek.
b2, err := io.ReadAll(req.Body)
require.NoError(t, err)
assert.Equal(t, body, string(b2), "downstream handler should receive the complete body")
}

// TestPeekRequestBody_EmptyBodySetsNoBody confirms that when the body is empty the
// request body is replaced with http.NoBody (not a lingering empty reader).
func TestPeekRequestBody_EmptyBodySetsNoBody(t *testing.T) {
t.Parallel()

req := httptest.NewRequest(http.MethodPost, "/mcp", bytes.NewBufferString(""))

got, err := peekRequestBody(req)
require.NoError(t, err)
assert.Empty(t, got)
assert.Equal(t, http.NoBody, req.Body, "empty body should be replaced with http.NoBody")
}
Loading