Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add context timeout middleware #2380

Merged
merged 11 commits into from
Feb 1, 2023
72 changes: 72 additions & 0 deletions middleware/context_timeout.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
package middleware

import (
"context"
"errors"
"time"

"github.com/labstack/echo/v4"
)

// ContextTimeoutConfig defines the config for ContextTimeout middleware.
type ContextTimeoutConfig struct {
// Skipper defines a function to skip middleware.
Skipper Skipper

// ErrorHandler is a function when error aries in middeware execution.
ErrorHandler func(err error, c echo.Context) error

// Timeout configures a timeout for the middleware, defaults to 0 for no timeout
Timeout time.Duration
}

// ContextTimeout returns a middleware which returns error (503 Service Unavailable error) to client
// when underlying method returns context.DeadlineExceeded error.
func ContextTimeout(timeout time.Duration) echo.MiddlewareFunc {
return ContextTimeoutWithConfig(ContextTimeoutConfig{Timeout: timeout})
}

// ContextTimeoutWithConfig returns a Timeout middleware with config.
func ContextTimeoutWithConfig(config ContextTimeoutConfig) echo.MiddlewareFunc {
mw, err := config.ToMiddleware()
if err != nil {
panic(err)
}
return mw
}

// ToMiddleware converts Config to middleware.
func (config ContextTimeoutConfig) ToMiddleware() (echo.MiddlewareFunc, error) {
if config.Timeout == 0 {
return nil, errors.New("timeout must be set")
}
if config.Skipper == nil {
config.Skipper = DefaultSkipper
}
if config.ErrorHandler == nil {
config.ErrorHandler = func(err error, c echo.Context) error {
if err != nil && errors.Is(err, context.DeadlineExceeded) {
return echo.ErrServiceUnavailable.WithInternal(err)
}
return err
}
}

return func(next echo.HandlerFunc) echo.HandlerFunc {
return func(c echo.Context) error {
if config.Skipper(c) {
return next(c)
}

timeoutContext, cancel := context.WithTimeout(c.Request().Context(), config.Timeout)
defer cancel()

c.SetRequest(c.Request().WithContext(timeoutContext))

if err := next(c); err != nil {
return config.ErrorHandler(err, c)
}
return nil
}
}, nil
}
226 changes: 226 additions & 0 deletions middleware/context_timeout_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,226 @@
package middleware

import (
"context"
"errors"
"net/http"
"net/http/httptest"
"net/url"
"strings"
"testing"
"time"

"github.com/labstack/echo/v4"
"github.com/stretchr/testify/assert"
)

func TestContextTimeoutSkipper(t *testing.T) {
t.Parallel()
m := ContextTimeoutWithConfig(ContextTimeoutConfig{
Skipper: func(context echo.Context) bool {
return true
},
Timeout: 10 * time.Millisecond,
})

req := httptest.NewRequest(http.MethodGet, "/", nil)
rec := httptest.NewRecorder()

e := echo.New()
c := e.NewContext(req, rec)

err := m(func(c echo.Context) error {
if err := sleepWithContext(c.Request().Context(), time.Duration(20*time.Millisecond)); err != nil {
return err
}

return errors.New("response from handler")
})(c)

// if not skipped we would have not returned error due context timeout logic
assert.EqualError(t, err, "response from handler")
}

func TestContextTimeoutWithTimeout0(t *testing.T) {
t.Parallel()
assert.Panics(t, func() {
ContextTimeout(time.Duration(0))
})
}

func TestContextTimeoutErrorOutInHandler(t *testing.T) {
t.Parallel()
m := ContextTimeoutWithConfig(ContextTimeoutConfig{
// Timeout has to be defined or the whole flow for timeout middleware will be skipped
Timeout: 10 * time.Millisecond,
})

req := httptest.NewRequest(http.MethodGet, "/", nil)
rec := httptest.NewRecorder()

e := echo.New()
c := e.NewContext(req, rec)

rec.Code = 1 // we want to be sure that even 200 will not be sent
err := m(func(c echo.Context) error {
// this error must not be written to the client response. Middlewares upstream of timeout middleware must be able
// to handle returned error and this can be done only then handler has not yet committed (written status code)
// the response.
return echo.NewHTTPError(http.StatusTeapot, "err")
})(c)

assert.Error(t, err)
assert.EqualError(t, err, "code=418, message=err")
assert.Equal(t, 1, rec.Code)
assert.Equal(t, "", rec.Body.String())
}

func TestContextTimeoutSuccessfulRequest(t *testing.T) {
t.Parallel()
m := ContextTimeoutWithConfig(ContextTimeoutConfig{
// Timeout has to be defined or the whole flow for timeout middleware will be skipped
Timeout: 10 * time.Millisecond,
})

req := httptest.NewRequest(http.MethodGet, "/", nil)
rec := httptest.NewRecorder()

e := echo.New()
c := e.NewContext(req, rec)

err := m(func(c echo.Context) error {
return c.JSON(http.StatusCreated, map[string]string{"data": "ok"})
})(c)

assert.NoError(t, err)
assert.Equal(t, http.StatusCreated, rec.Code)
assert.Equal(t, "{\"data\":\"ok\"}\n", rec.Body.String())
}

func TestContextTimeoutTestRequestClone(t *testing.T) {
t.Parallel()
req := httptest.NewRequest(http.MethodPost, "/uri?query=value", strings.NewReader(url.Values{"form": {"value"}}.Encode()))
req.AddCookie(&http.Cookie{Name: "cookie", Value: "value"})
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
rec := httptest.NewRecorder()

m := ContextTimeoutWithConfig(ContextTimeoutConfig{
// Timeout has to be defined or the whole flow for timeout middleware will be skipped
Timeout: 1 * time.Second,
})

e := echo.New()
c := e.NewContext(req, rec)

err := m(func(c echo.Context) error {
// Cookie test
cookie, err := c.Request().Cookie("cookie")
if assert.NoError(t, err) {
assert.EqualValues(t, "cookie", cookie.Name)
assert.EqualValues(t, "value", cookie.Value)
}

// Form values
if assert.NoError(t, c.Request().ParseForm()) {
assert.EqualValues(t, "value", c.Request().FormValue("form"))
}

// Query string
assert.EqualValues(t, "value", c.Request().URL.Query()["query"][0])
return nil
})(c)

assert.NoError(t, err)
}

func TestContextTimeoutWithDefaultErrorMessage(t *testing.T) {
t.Parallel()

timeout := 10 * time.Millisecond
m := ContextTimeoutWithConfig(ContextTimeoutConfig{
Timeout: timeout,
})

req := httptest.NewRequest(http.MethodGet, "/", nil)
rec := httptest.NewRecorder()

e := echo.New()
c := e.NewContext(req, rec)

err := m(func(c echo.Context) error {
if err := sleepWithContext(c.Request().Context(), time.Duration(20*time.Millisecond)); err != nil {
return err
}
return c.String(http.StatusOK, "Hello, World!")
})(c)

assert.IsType(t, &echo.HTTPError{}, err)
assert.Error(t, err)
assert.Equal(t, http.StatusServiceUnavailable, err.(*echo.HTTPError).Code)
assert.Equal(t, "Service Unavailable", err.(*echo.HTTPError).Message)
}

func TestContextTimeoutCanHandleContextDeadlineOnNextHandler(t *testing.T) {
t.Parallel()

timeoutErrorHandler := func(err error, c echo.Context) error {
if err != nil {
if errors.Is(err, context.DeadlineExceeded) {
return &echo.HTTPError{
Code: http.StatusServiceUnavailable,
Message: "Timeout! change me",
}
}
return err
}
return nil
}

timeout := 10 * time.Millisecond
m := ContextTimeoutWithConfig(ContextTimeoutConfig{
Timeout: timeout,
ErrorHandler: timeoutErrorHandler,
})

req := httptest.NewRequest(http.MethodGet, "/", nil)
rec := httptest.NewRecorder()

e := echo.New()
c := e.NewContext(req, rec)

err := m(func(c echo.Context) error {
// NOTE: when difference between timeout duration and handler execution time is almost the same (in range of 100microseconds)
// the result of timeout does not seem to be reliable - could respond timeout, could respond handler output
// difference over 500microseconds (0.5millisecond) response seems to be reliable

if err := sleepWithContext(c.Request().Context(), time.Duration(20*time.Millisecond)); err != nil {
return err
}

// The Request Context should have a Deadline set by http.ContextTimeoutHandler
if _, ok := c.Request().Context().Deadline(); !ok {
assert.Fail(t, "No timeout set on Request Context")
}
return c.String(http.StatusOK, "Hello, World!")
})(c)

assert.IsType(t, &echo.HTTPError{}, err)
assert.Error(t, err)
assert.Equal(t, http.StatusServiceUnavailable, err.(*echo.HTTPError).Code)
assert.Equal(t, "Timeout! change me", err.(*echo.HTTPError).Message)
}

func sleepWithContext(ctx context.Context, d time.Duration) error {
timer := time.NewTimer(d)

defer func() {
_ = timer.Stop()
}()

select {
case <-ctx.Done():
return context.DeadlineExceeded
case <-timer.C:
return nil
}
}