Skip to content

Commit

Permalink
add a custom error handler to key-auth middleware (#1847)
Browse files Browse the repository at this point in the history
* add a custom error handler to key-auth middleware
  • Loading branch information
hyacinthus authored May 8, 2021
1 parent 76f186a commit 7256cb2
Show file tree
Hide file tree
Showing 2 changed files with 223 additions and 50 deletions.
13 changes: 13 additions & 0 deletions middleware/key_auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,12 +30,19 @@ type (
// Validator is a function to validate key.
// Required.
Validator KeyAuthValidator

// ErrorHandler defines a function which is executed for an invalid key.
// It may be used to define a custom error.
ErrorHandler KeyAuthErrorHandler
}

// KeyAuthValidator defines a function to validate KeyAuth credentials.
KeyAuthValidator func(string, echo.Context) (bool, error)

keyExtractor func(echo.Context) (string, error)

// KeyAuthErrorHandler defines a function which is executed for an invalid key.
KeyAuthErrorHandler func(error, echo.Context) error
)

var (
Expand Down Expand Up @@ -95,10 +102,16 @@ func KeyAuthWithConfig(config KeyAuthConfig) echo.MiddlewareFunc {
// Extract and verify key
key, err := extractor(c)
if err != nil {
if config.ErrorHandler != nil {
return config.ErrorHandler(err, c)
}
return echo.NewHTTPError(http.StatusBadRequest, err.Error())
}
valid, err := config.Validator(key, c)
if err != nil {
if config.ErrorHandler != nil {
return config.ErrorHandler(err, c)
}
return &echo.HTTPError{
Code: http.StatusUnauthorized,
Message: "invalid key",
Expand Down
260 changes: 210 additions & 50 deletions middleware/key_auth_test.go
Original file line number Diff line number Diff line change
@@ -1,75 +1,235 @@
package middleware

import (
"errors"
"net/http"
"net/http/httptest"
"net/url"
"strings"
"testing"

"github.com/labstack/echo/v4"
"github.com/stretchr/testify/assert"
)

func testKeyValidator(key string, c echo.Context) (bool, error) {
switch key {
case "valid-key":
return true, nil
case "error-key":
return false, errors.New("some user defined error")
default:
return false, nil
}
}

func TestKeyAuth(t *testing.T) {
handlerCalled := false
handler := func(c echo.Context) error {
handlerCalled = true
return c.String(http.StatusOK, "test")
}
middlewareChain := KeyAuth(testKeyValidator)(handler)

e := echo.New()
req := httptest.NewRequest(http.MethodGet, "/", nil)
req.Header.Set(echo.HeaderAuthorization, "Bearer valid-key")
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
config := KeyAuthConfig{
Validator: func(key string, c echo.Context) (bool, error) {
return key == "valid-key", nil
},
}
h := KeyAuthWithConfig(config)(func(c echo.Context) error {
return c.String(http.StatusOK, "test")
})

assert := assert.New(t)
err := middlewareChain(c)

// Valid key
auth := DefaultKeyAuthConfig.AuthScheme + " " + "valid-key"
req.Header.Set(echo.HeaderAuthorization, auth)
assert.NoError(h(c))
assert.NoError(t, err)
assert.True(t, handlerCalled)
}

// Invalid key
auth = DefaultKeyAuthConfig.AuthScheme + " " + "invalid-key"
req.Header.Set(echo.HeaderAuthorization, auth)
he := h(c).(*echo.HTTPError)
assert.Equal(http.StatusUnauthorized, he.Code)
func TestKeyAuthWithConfig(t *testing.T) {
var testCases = []struct {
name string
givenRequestFunc func() *http.Request
givenRequest func(req *http.Request)
whenConfig func(conf *KeyAuthConfig)
expectHandlerCalled bool
expectError string
}{
{
name: "ok, defaults, key from header",
givenRequest: func(req *http.Request) {
req.Header.Set(echo.HeaderAuthorization, "Bearer valid-key")
},
expectHandlerCalled: true,
},
{
name: "ok, custom skipper",
givenRequest: func(req *http.Request) {
req.Header.Set(echo.HeaderAuthorization, "Bearer error-key")
},
whenConfig: func(conf *KeyAuthConfig) {
conf.Skipper = func(context echo.Context) bool {
return true
}
},
expectHandlerCalled: true,
},
{
name: "nok, defaults, invalid key from header, Authorization: Bearer",
givenRequest: func(req *http.Request) {
req.Header.Set(echo.HeaderAuthorization, "Bearer invalid-key")
},
expectHandlerCalled: false,
expectError: "code=401, message=Unauthorized",
},
{
name: "nok, defaults, invalid scheme in header",
givenRequest: func(req *http.Request) {
req.Header.Set(echo.HeaderAuthorization, "Bear valid-key")
},
expectHandlerCalled: false,
expectError: "code=400, message=invalid key in the request header",
},
{
name: "nok, defaults, missing header",
givenRequest: func(req *http.Request) {},
expectHandlerCalled: false,
expectError: "code=400, message=missing key in request header",
},
{
name: "ok, custom key lookup, header",
givenRequest: func(req *http.Request) {
req.Header.Set("API-Key", "valid-key")
},
whenConfig: func(conf *KeyAuthConfig) {
conf.KeyLookup = "header:API-Key"
},
expectHandlerCalled: true,
},
{
name: "nok, custom key lookup, missing header",
givenRequest: func(req *http.Request) {
},
whenConfig: func(conf *KeyAuthConfig) {
conf.KeyLookup = "header:API-Key"
},
expectHandlerCalled: false,
expectError: "code=400, message=missing key in request header",
},
{
name: "ok, custom key lookup, query",
givenRequest: func(req *http.Request) {
q := req.URL.Query()
q.Add("key", "valid-key")
req.URL.RawQuery = q.Encode()
},
whenConfig: func(conf *KeyAuthConfig) {
conf.KeyLookup = "query:key"
},
expectHandlerCalled: true,
},
{
name: "nok, custom key lookup, missing query param",
whenConfig: func(conf *KeyAuthConfig) {
conf.KeyLookup = "query:key"
},
expectHandlerCalled: false,
expectError: "code=400, message=missing key in the query string",
},
{
name: "ok, custom key lookup, form",
givenRequestFunc: func() *http.Request {
req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader("key=valid-key"))
req.Header.Set(echo.HeaderContentType, echo.MIMEApplicationForm)
return req
},
whenConfig: func(conf *KeyAuthConfig) {
conf.KeyLookup = "form:key"
},
expectHandlerCalled: true,
},
{
name: "nok, custom key lookup, missing key in form",
givenRequestFunc: func() *http.Request {
req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader("xxx=valid-key"))
req.Header.Set(echo.HeaderContentType, echo.MIMEApplicationForm)
return req
},
whenConfig: func(conf *KeyAuthConfig) {
conf.KeyLookup = "form:key"
},
expectHandlerCalled: false,
expectError: "code=400, message=missing key in the form",
},
{
name: "nok, custom errorHandler, error from extractor",
whenConfig: func(conf *KeyAuthConfig) {
conf.KeyLookup = "header:token"
conf.ErrorHandler = func(err error, context echo.Context) error {
httpError := echo.NewHTTPError(http.StatusTeapot, "custom")
httpError.Internal = err
return httpError
}
},
expectHandlerCalled: false,
expectError: "code=418, message=custom, internal=missing key in request header",
},
{
name: "nok, custom errorHandler, error from validator",
givenRequest: func(req *http.Request) {
req.Header.Set(echo.HeaderAuthorization, "Bearer error-key")
},
whenConfig: func(conf *KeyAuthConfig) {
conf.ErrorHandler = func(err error, context echo.Context) error {
httpError := echo.NewHTTPError(http.StatusTeapot, "custom")
httpError.Internal = err
return httpError
}
},
expectHandlerCalled: false,
expectError: "code=418, message=custom, internal=some user defined error",
},
{
name: "nok, defaults, error from validator",
givenRequest: func(req *http.Request) {
req.Header.Set(echo.HeaderAuthorization, "Bearer error-key")
},
whenConfig: func(conf *KeyAuthConfig) {},
expectHandlerCalled: false,
expectError: "code=401, message=invalid key, internal=some user defined error",
},
}

// Missing Authorization header
req.Header.Del(echo.HeaderAuthorization)
he = h(c).(*echo.HTTPError)
assert.Equal(http.StatusBadRequest, he.Code)
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
handlerCalled := false
handler := func(c echo.Context) error {
handlerCalled = true
return c.String(http.StatusOK, "test")
}
config := KeyAuthConfig{
Validator: testKeyValidator,
}
if tc.whenConfig != nil {
tc.whenConfig(&config)
}
middlewareChain := KeyAuthWithConfig(config)(handler)

// Key from custom header
config.KeyLookup = "header:API-Key"
h = KeyAuthWithConfig(config)(func(c echo.Context) error {
return c.String(http.StatusOK, "test")
})
req.Header.Set("API-Key", "valid-key")
assert.NoError(h(c))
e := echo.New()
req := httptest.NewRequest(http.MethodGet, "/", nil)
if tc.givenRequestFunc != nil {
req = tc.givenRequestFunc()
}
if tc.givenRequest != nil {
tc.givenRequest(req)
}
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)

// Key from query string
config.KeyLookup = "query:key"
h = KeyAuthWithConfig(config)(func(c echo.Context) error {
return c.String(http.StatusOK, "test")
})
q := req.URL.Query()
q.Add("key", "valid-key")
req.URL.RawQuery = q.Encode()
assert.NoError(h(c))
err := middlewareChain(c)

// Key from form
config.KeyLookup = "form:key"
h = KeyAuthWithConfig(config)(func(c echo.Context) error {
return c.String(http.StatusOK, "test")
})
f := make(url.Values)
f.Set("key", "valid-key")
req = httptest.NewRequest(http.MethodPost, "/", strings.NewReader(f.Encode()))
req.Header.Set(echo.HeaderContentType, echo.MIMEApplicationForm)
c = e.NewContext(req, rec)
assert.NoError(h(c))
assert.Equal(t, tc.expectHandlerCalled, handlerCalled)
if tc.expectError != "" {
assert.EqualError(t, err, tc.expectError)
} else {
assert.NoError(t, err)
}
})
}
}

0 comments on commit 7256cb2

Please sign in to comment.