Skip to content

Commit 7256cb2

Browse files
authored
add a custom error handler to key-auth middleware (#1847)
* add a custom error handler to key-auth middleware
1 parent 76f186a commit 7256cb2

File tree

2 files changed

+223
-50
lines changed

2 files changed

+223
-50
lines changed

middleware/key_auth.go

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,12 +30,19 @@ type (
3030
// Validator is a function to validate key.
3131
// Required.
3232
Validator KeyAuthValidator
33+
34+
// ErrorHandler defines a function which is executed for an invalid key.
35+
// It may be used to define a custom error.
36+
ErrorHandler KeyAuthErrorHandler
3337
}
3438

3539
// KeyAuthValidator defines a function to validate KeyAuth credentials.
3640
KeyAuthValidator func(string, echo.Context) (bool, error)
3741

3842
keyExtractor func(echo.Context) (string, error)
43+
44+
// KeyAuthErrorHandler defines a function which is executed for an invalid key.
45+
KeyAuthErrorHandler func(error, echo.Context) error
3946
)
4047

4148
var (
@@ -95,10 +102,16 @@ func KeyAuthWithConfig(config KeyAuthConfig) echo.MiddlewareFunc {
95102
// Extract and verify key
96103
key, err := extractor(c)
97104
if err != nil {
105+
if config.ErrorHandler != nil {
106+
return config.ErrorHandler(err, c)
107+
}
98108
return echo.NewHTTPError(http.StatusBadRequest, err.Error())
99109
}
100110
valid, err := config.Validator(key, c)
101111
if err != nil {
112+
if config.ErrorHandler != nil {
113+
return config.ErrorHandler(err, c)
114+
}
102115
return &echo.HTTPError{
103116
Code: http.StatusUnauthorized,
104117
Message: "invalid key",

middleware/key_auth_test.go

Lines changed: 210 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -1,75 +1,235 @@
11
package middleware
22

33
import (
4+
"errors"
45
"net/http"
56
"net/http/httptest"
6-
"net/url"
77
"strings"
88
"testing"
99

1010
"github.com/labstack/echo/v4"
1111
"github.com/stretchr/testify/assert"
1212
)
1313

14+
func testKeyValidator(key string, c echo.Context) (bool, error) {
15+
switch key {
16+
case "valid-key":
17+
return true, nil
18+
case "error-key":
19+
return false, errors.New("some user defined error")
20+
default:
21+
return false, nil
22+
}
23+
}
24+
1425
func TestKeyAuth(t *testing.T) {
26+
handlerCalled := false
27+
handler := func(c echo.Context) error {
28+
handlerCalled = true
29+
return c.String(http.StatusOK, "test")
30+
}
31+
middlewareChain := KeyAuth(testKeyValidator)(handler)
32+
1533
e := echo.New()
1634
req := httptest.NewRequest(http.MethodGet, "/", nil)
35+
req.Header.Set(echo.HeaderAuthorization, "Bearer valid-key")
1736
rec := httptest.NewRecorder()
1837
c := e.NewContext(req, rec)
19-
config := KeyAuthConfig{
20-
Validator: func(key string, c echo.Context) (bool, error) {
21-
return key == "valid-key", nil
22-
},
23-
}
24-
h := KeyAuthWithConfig(config)(func(c echo.Context) error {
25-
return c.String(http.StatusOK, "test")
26-
})
2738

28-
assert := assert.New(t)
39+
err := middlewareChain(c)
2940

30-
// Valid key
31-
auth := DefaultKeyAuthConfig.AuthScheme + " " + "valid-key"
32-
req.Header.Set(echo.HeaderAuthorization, auth)
33-
assert.NoError(h(c))
41+
assert.NoError(t, err)
42+
assert.True(t, handlerCalled)
43+
}
3444

35-
// Invalid key
36-
auth = DefaultKeyAuthConfig.AuthScheme + " " + "invalid-key"
37-
req.Header.Set(echo.HeaderAuthorization, auth)
38-
he := h(c).(*echo.HTTPError)
39-
assert.Equal(http.StatusUnauthorized, he.Code)
45+
func TestKeyAuthWithConfig(t *testing.T) {
46+
var testCases = []struct {
47+
name string
48+
givenRequestFunc func() *http.Request
49+
givenRequest func(req *http.Request)
50+
whenConfig func(conf *KeyAuthConfig)
51+
expectHandlerCalled bool
52+
expectError string
53+
}{
54+
{
55+
name: "ok, defaults, key from header",
56+
givenRequest: func(req *http.Request) {
57+
req.Header.Set(echo.HeaderAuthorization, "Bearer valid-key")
58+
},
59+
expectHandlerCalled: true,
60+
},
61+
{
62+
name: "ok, custom skipper",
63+
givenRequest: func(req *http.Request) {
64+
req.Header.Set(echo.HeaderAuthorization, "Bearer error-key")
65+
},
66+
whenConfig: func(conf *KeyAuthConfig) {
67+
conf.Skipper = func(context echo.Context) bool {
68+
return true
69+
}
70+
},
71+
expectHandlerCalled: true,
72+
},
73+
{
74+
name: "nok, defaults, invalid key from header, Authorization: Bearer",
75+
givenRequest: func(req *http.Request) {
76+
req.Header.Set(echo.HeaderAuthorization, "Bearer invalid-key")
77+
},
78+
expectHandlerCalled: false,
79+
expectError: "code=401, message=Unauthorized",
80+
},
81+
{
82+
name: "nok, defaults, invalid scheme in header",
83+
givenRequest: func(req *http.Request) {
84+
req.Header.Set(echo.HeaderAuthorization, "Bear valid-key")
85+
},
86+
expectHandlerCalled: false,
87+
expectError: "code=400, message=invalid key in the request header",
88+
},
89+
{
90+
name: "nok, defaults, missing header",
91+
givenRequest: func(req *http.Request) {},
92+
expectHandlerCalled: false,
93+
expectError: "code=400, message=missing key in request header",
94+
},
95+
{
96+
name: "ok, custom key lookup, header",
97+
givenRequest: func(req *http.Request) {
98+
req.Header.Set("API-Key", "valid-key")
99+
},
100+
whenConfig: func(conf *KeyAuthConfig) {
101+
conf.KeyLookup = "header:API-Key"
102+
},
103+
expectHandlerCalled: true,
104+
},
105+
{
106+
name: "nok, custom key lookup, missing header",
107+
givenRequest: func(req *http.Request) {
108+
},
109+
whenConfig: func(conf *KeyAuthConfig) {
110+
conf.KeyLookup = "header:API-Key"
111+
},
112+
expectHandlerCalled: false,
113+
expectError: "code=400, message=missing key in request header",
114+
},
115+
{
116+
name: "ok, custom key lookup, query",
117+
givenRequest: func(req *http.Request) {
118+
q := req.URL.Query()
119+
q.Add("key", "valid-key")
120+
req.URL.RawQuery = q.Encode()
121+
},
122+
whenConfig: func(conf *KeyAuthConfig) {
123+
conf.KeyLookup = "query:key"
124+
},
125+
expectHandlerCalled: true,
126+
},
127+
{
128+
name: "nok, custom key lookup, missing query param",
129+
whenConfig: func(conf *KeyAuthConfig) {
130+
conf.KeyLookup = "query:key"
131+
},
132+
expectHandlerCalled: false,
133+
expectError: "code=400, message=missing key in the query string",
134+
},
135+
{
136+
name: "ok, custom key lookup, form",
137+
givenRequestFunc: func() *http.Request {
138+
req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader("key=valid-key"))
139+
req.Header.Set(echo.HeaderContentType, echo.MIMEApplicationForm)
140+
return req
141+
},
142+
whenConfig: func(conf *KeyAuthConfig) {
143+
conf.KeyLookup = "form:key"
144+
},
145+
expectHandlerCalled: true,
146+
},
147+
{
148+
name: "nok, custom key lookup, missing key in form",
149+
givenRequestFunc: func() *http.Request {
150+
req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader("xxx=valid-key"))
151+
req.Header.Set(echo.HeaderContentType, echo.MIMEApplicationForm)
152+
return req
153+
},
154+
whenConfig: func(conf *KeyAuthConfig) {
155+
conf.KeyLookup = "form:key"
156+
},
157+
expectHandlerCalled: false,
158+
expectError: "code=400, message=missing key in the form",
159+
},
160+
{
161+
name: "nok, custom errorHandler, error from extractor",
162+
whenConfig: func(conf *KeyAuthConfig) {
163+
conf.KeyLookup = "header:token"
164+
conf.ErrorHandler = func(err error, context echo.Context) error {
165+
httpError := echo.NewHTTPError(http.StatusTeapot, "custom")
166+
httpError.Internal = err
167+
return httpError
168+
}
169+
},
170+
expectHandlerCalled: false,
171+
expectError: "code=418, message=custom, internal=missing key in request header",
172+
},
173+
{
174+
name: "nok, custom errorHandler, error from validator",
175+
givenRequest: func(req *http.Request) {
176+
req.Header.Set(echo.HeaderAuthorization, "Bearer error-key")
177+
},
178+
whenConfig: func(conf *KeyAuthConfig) {
179+
conf.ErrorHandler = func(err error, context echo.Context) error {
180+
httpError := echo.NewHTTPError(http.StatusTeapot, "custom")
181+
httpError.Internal = err
182+
return httpError
183+
}
184+
},
185+
expectHandlerCalled: false,
186+
expectError: "code=418, message=custom, internal=some user defined error",
187+
},
188+
{
189+
name: "nok, defaults, error from validator",
190+
givenRequest: func(req *http.Request) {
191+
req.Header.Set(echo.HeaderAuthorization, "Bearer error-key")
192+
},
193+
whenConfig: func(conf *KeyAuthConfig) {},
194+
expectHandlerCalled: false,
195+
expectError: "code=401, message=invalid key, internal=some user defined error",
196+
},
197+
}
40198

41-
// Missing Authorization header
42-
req.Header.Del(echo.HeaderAuthorization)
43-
he = h(c).(*echo.HTTPError)
44-
assert.Equal(http.StatusBadRequest, he.Code)
199+
for _, tc := range testCases {
200+
t.Run(tc.name, func(t *testing.T) {
201+
handlerCalled := false
202+
handler := func(c echo.Context) error {
203+
handlerCalled = true
204+
return c.String(http.StatusOK, "test")
205+
}
206+
config := KeyAuthConfig{
207+
Validator: testKeyValidator,
208+
}
209+
if tc.whenConfig != nil {
210+
tc.whenConfig(&config)
211+
}
212+
middlewareChain := KeyAuthWithConfig(config)(handler)
45213

46-
// Key from custom header
47-
config.KeyLookup = "header:API-Key"
48-
h = KeyAuthWithConfig(config)(func(c echo.Context) error {
49-
return c.String(http.StatusOK, "test")
50-
})
51-
req.Header.Set("API-Key", "valid-key")
52-
assert.NoError(h(c))
214+
e := echo.New()
215+
req := httptest.NewRequest(http.MethodGet, "/", nil)
216+
if tc.givenRequestFunc != nil {
217+
req = tc.givenRequestFunc()
218+
}
219+
if tc.givenRequest != nil {
220+
tc.givenRequest(req)
221+
}
222+
rec := httptest.NewRecorder()
223+
c := e.NewContext(req, rec)
53224

54-
// Key from query string
55-
config.KeyLookup = "query:key"
56-
h = KeyAuthWithConfig(config)(func(c echo.Context) error {
57-
return c.String(http.StatusOK, "test")
58-
})
59-
q := req.URL.Query()
60-
q.Add("key", "valid-key")
61-
req.URL.RawQuery = q.Encode()
62-
assert.NoError(h(c))
225+
err := middlewareChain(c)
63226

64-
// Key from form
65-
config.KeyLookup = "form:key"
66-
h = KeyAuthWithConfig(config)(func(c echo.Context) error {
67-
return c.String(http.StatusOK, "test")
68-
})
69-
f := make(url.Values)
70-
f.Set("key", "valid-key")
71-
req = httptest.NewRequest(http.MethodPost, "/", strings.NewReader(f.Encode()))
72-
req.Header.Set(echo.HeaderContentType, echo.MIMEApplicationForm)
73-
c = e.NewContext(req, rec)
74-
assert.NoError(h(c))
227+
assert.Equal(t, tc.expectHandlerCalled, handlerCalled)
228+
if tc.expectError != "" {
229+
assert.EqualError(t, err, tc.expectError)
230+
} else {
231+
assert.NoError(t, err)
232+
}
233+
})
234+
}
75235
}

0 commit comments

Comments
 (0)