Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add a custom error handler to key-auth middleware #1847

Merged
merged 3 commits into from
May 8, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 13 additions & 0 deletions middleware/key_auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,12 +30,19 @@ type (
// Validator is a function to validate key.
// Required.
Validator KeyAuthValidator

// ErrorHandler defines a function which is executed for an invalid key.
// It may be used to define a custom error.
ErrorHandler KeyAuthErrorHandler
}

// KeyAuthValidator defines a function to validate KeyAuth credentials.
KeyAuthValidator func(string, echo.Context) (bool, error)

keyExtractor func(echo.Context) (string, error)

// KeyAuthErrorHandler defines a function which is executed for an invalid key.
KeyAuthErrorHandler func(error, echo.Context) error
)

var (
Expand Down Expand Up @@ -95,10 +102,16 @@ func KeyAuthWithConfig(config KeyAuthConfig) echo.MiddlewareFunc {
// Extract and verify key
key, err := extractor(c)
if err != nil {
if config.ErrorHandler != nil {
return config.ErrorHandler(err, c)
}
return echo.NewHTTPError(http.StatusBadRequest, err.Error())
}
valid, err := config.Validator(key, c)
if err != nil {
if config.ErrorHandler != nil {
return config.ErrorHandler(err, c)
}
return &echo.HTTPError{
Code: http.StatusUnauthorized,
Message: "invalid key",
Expand Down
260 changes: 210 additions & 50 deletions middleware/key_auth_test.go
Original file line number Diff line number Diff line change
@@ -1,75 +1,235 @@
package middleware

import (
"errors"
"net/http"
"net/http/httptest"
"net/url"
"strings"
"testing"

"github.com/labstack/echo/v4"
"github.com/stretchr/testify/assert"
)

func testKeyValidator(key string, c echo.Context) (bool, error) {
switch key {
case "valid-key":
return true, nil
case "error-key":
return false, errors.New("some user defined error")
default:
return false, nil
}
}

func TestKeyAuth(t *testing.T) {
handlerCalled := false
handler := func(c echo.Context) error {
handlerCalled = true
return c.String(http.StatusOK, "test")
}
middlewareChain := KeyAuth(testKeyValidator)(handler)

e := echo.New()
req := httptest.NewRequest(http.MethodGet, "/", nil)
req.Header.Set(echo.HeaderAuthorization, "Bearer valid-key")
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
config := KeyAuthConfig{
Validator: func(key string, c echo.Context) (bool, error) {
return key == "valid-key", nil
},
}
h := KeyAuthWithConfig(config)(func(c echo.Context) error {
return c.String(http.StatusOK, "test")
})

assert := assert.New(t)
err := middlewareChain(c)

// Valid key
auth := DefaultKeyAuthConfig.AuthScheme + " " + "valid-key"
req.Header.Set(echo.HeaderAuthorization, auth)
assert.NoError(h(c))
assert.NoError(t, err)
assert.True(t, handlerCalled)
}

// Invalid key
auth = DefaultKeyAuthConfig.AuthScheme + " " + "invalid-key"
req.Header.Set(echo.HeaderAuthorization, auth)
he := h(c).(*echo.HTTPError)
assert.Equal(http.StatusUnauthorized, he.Code)
func TestKeyAuthWithConfig(t *testing.T) {
var testCases = []struct {
name string
givenRequestFunc func() *http.Request
givenRequest func(req *http.Request)
whenConfig func(conf *KeyAuthConfig)
expectHandlerCalled bool
expectError string
}{
{
name: "ok, defaults, key from header",
givenRequest: func(req *http.Request) {
req.Header.Set(echo.HeaderAuthorization, "Bearer valid-key")
},
expectHandlerCalled: true,
},
{
name: "ok, custom skipper",
givenRequest: func(req *http.Request) {
req.Header.Set(echo.HeaderAuthorization, "Bearer error-key")
},
whenConfig: func(conf *KeyAuthConfig) {
conf.Skipper = func(context echo.Context) bool {
return true
}
},
expectHandlerCalled: true,
},
{
name: "nok, defaults, invalid key from header, Authorization: Bearer",
givenRequest: func(req *http.Request) {
req.Header.Set(echo.HeaderAuthorization, "Bearer invalid-key")
},
expectHandlerCalled: false,
expectError: "code=401, message=Unauthorized",
},
{
name: "nok, defaults, invalid scheme in header",
givenRequest: func(req *http.Request) {
req.Header.Set(echo.HeaderAuthorization, "Bear valid-key")
},
expectHandlerCalled: false,
expectError: "code=400, message=invalid key in the request header",
},
{
name: "nok, defaults, missing header",
givenRequest: func(req *http.Request) {},
expectHandlerCalled: false,
expectError: "code=400, message=missing key in request header",
},
{
name: "ok, custom key lookup, header",
givenRequest: func(req *http.Request) {
req.Header.Set("API-Key", "valid-key")
},
whenConfig: func(conf *KeyAuthConfig) {
conf.KeyLookup = "header:API-Key"
},
expectHandlerCalled: true,
},
{
name: "nok, custom key lookup, missing header",
givenRequest: func(req *http.Request) {
},
whenConfig: func(conf *KeyAuthConfig) {
conf.KeyLookup = "header:API-Key"
},
expectHandlerCalled: false,
expectError: "code=400, message=missing key in request header",
},
{
name: "ok, custom key lookup, query",
givenRequest: func(req *http.Request) {
q := req.URL.Query()
q.Add("key", "valid-key")
req.URL.RawQuery = q.Encode()
},
whenConfig: func(conf *KeyAuthConfig) {
conf.KeyLookup = "query:key"
},
expectHandlerCalled: true,
},
{
name: "nok, custom key lookup, missing query param",
whenConfig: func(conf *KeyAuthConfig) {
conf.KeyLookup = "query:key"
},
expectHandlerCalled: false,
expectError: "code=400, message=missing key in the query string",
},
{
name: "ok, custom key lookup, form",
givenRequestFunc: func() *http.Request {
req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader("key=valid-key"))
req.Header.Set(echo.HeaderContentType, echo.MIMEApplicationForm)
return req
},
whenConfig: func(conf *KeyAuthConfig) {
conf.KeyLookup = "form:key"
},
expectHandlerCalled: true,
},
{
name: "nok, custom key lookup, missing key in form",
givenRequestFunc: func() *http.Request {
req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader("xxx=valid-key"))
req.Header.Set(echo.HeaderContentType, echo.MIMEApplicationForm)
return req
},
whenConfig: func(conf *KeyAuthConfig) {
conf.KeyLookup = "form:key"
},
expectHandlerCalled: false,
expectError: "code=400, message=missing key in the form",
},
{
name: "nok, custom errorHandler, error from extractor",
whenConfig: func(conf *KeyAuthConfig) {
conf.KeyLookup = "header:token"
conf.ErrorHandler = func(err error, context echo.Context) error {
httpError := echo.NewHTTPError(http.StatusTeapot, "custom")
httpError.Internal = err
return httpError
}
},
expectHandlerCalled: false,
expectError: "code=418, message=custom, internal=missing key in request header",
},
{
name: "nok, custom errorHandler, error from validator",
givenRequest: func(req *http.Request) {
req.Header.Set(echo.HeaderAuthorization, "Bearer error-key")
},
whenConfig: func(conf *KeyAuthConfig) {
conf.ErrorHandler = func(err error, context echo.Context) error {
httpError := echo.NewHTTPError(http.StatusTeapot, "custom")
httpError.Internal = err
return httpError
}
},
expectHandlerCalled: false,
expectError: "code=418, message=custom, internal=some user defined error",
},
{
name: "nok, defaults, error from validator",
givenRequest: func(req *http.Request) {
req.Header.Set(echo.HeaderAuthorization, "Bearer error-key")
},
whenConfig: func(conf *KeyAuthConfig) {},
expectHandlerCalled: false,
expectError: "code=401, message=invalid key, internal=some user defined error",
},
}

// Missing Authorization header
req.Header.Del(echo.HeaderAuthorization)
he = h(c).(*echo.HTTPError)
assert.Equal(http.StatusBadRequest, he.Code)
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
handlerCalled := false
handler := func(c echo.Context) error {
handlerCalled = true
return c.String(http.StatusOK, "test")
}
config := KeyAuthConfig{
Validator: testKeyValidator,
}
if tc.whenConfig != nil {
tc.whenConfig(&config)
}
middlewareChain := KeyAuthWithConfig(config)(handler)

// Key from custom header
config.KeyLookup = "header:API-Key"
h = KeyAuthWithConfig(config)(func(c echo.Context) error {
return c.String(http.StatusOK, "test")
})
req.Header.Set("API-Key", "valid-key")
assert.NoError(h(c))
e := echo.New()
req := httptest.NewRequest(http.MethodGet, "/", nil)
if tc.givenRequestFunc != nil {
req = tc.givenRequestFunc()
}
if tc.givenRequest != nil {
tc.givenRequest(req)
}
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)

// Key from query string
config.KeyLookup = "query:key"
h = KeyAuthWithConfig(config)(func(c echo.Context) error {
return c.String(http.StatusOK, "test")
})
q := req.URL.Query()
q.Add("key", "valid-key")
req.URL.RawQuery = q.Encode()
assert.NoError(h(c))
err := middlewareChain(c)

// Key from form
config.KeyLookup = "form:key"
h = KeyAuthWithConfig(config)(func(c echo.Context) error {
return c.String(http.StatusOK, "test")
})
f := make(url.Values)
f.Set("key", "valid-key")
req = httptest.NewRequest(http.MethodPost, "/", strings.NewReader(f.Encode()))
req.Header.Set(echo.HeaderContentType, echo.MIMEApplicationForm)
c = e.NewContext(req, rec)
assert.NoError(h(c))
assert.Equal(t, tc.expectHandlerCalled, handlerCalled)
if tc.expectError != "" {
assert.EqualError(t, err, tc.expectError)
} else {
assert.NoError(t, err)
}
})
}
}