Skip to content

Commit 73fa05f

Browse files
committed
Added panic recover middleware
Signed-off-by: Vishal Rana <vr@labstack.com>
1 parent 609879b commit 73fa05f

File tree

18 files changed

+166
-99
lines changed

18 files changed

+166
-99
lines changed

README.md

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -90,14 +90,21 @@ func main() {
9090
// Echo instance
9191
e := echo.New()
9292

93+
//------------
9394
// Middleware
95+
//------------
96+
97+
// Recover
98+
e.Use(mw.Recover())
99+
100+
// Logger
94101
e.Use(mw.Logger())
95102

96103
// Routes
97104
e.Get("/", hello)
98105

99106
// Start server
100-
e.Run(":1323)
107+
e.Run(":1323")
101108
}
102109
```
103110

context.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ import (
77

88
type (
99
// Context represents context for the current request. It holds request and
10-
// response references, path parameters, data and registered handler.
10+
// response objects, path parameters, data and registered handler.
1111
Context struct {
1212
Request *http.Request
1313
Response *Response

echo.go

Lines changed: 23 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -22,12 +22,12 @@ type (
2222
prefix string
2323
middleware []MiddlewareFunc
2424
maxParam byte
25-
notFoundHandler HandlerFunc
2625
httpErrorHandler HTTPErrorHandler
2726
binder BindFunc
2827
renderer Renderer
2928
uris map[Handler]string
3029
pool sync.Pool
30+
debug bool
3131
}
3232
HTTPError struct {
3333
Code int
@@ -115,8 +115,8 @@ var (
115115
// Errors
116116
//--------
117117

118-
UnsupportedMediaType = errors.New("echo: unsupported media type")
119-
RendererNotRegistered = errors.New("echo: renderer not registered")
118+
UnsupportedMediaType = errors.New("echo unsupported media type")
119+
RendererNotRegistered = errors.New("echo renderer not registered")
120120
)
121121

122122
// New creates an Echo instance.
@@ -134,19 +134,14 @@ func New() (e *Echo) {
134134
//----------
135135

136136
e.MaxParam(5)
137-
e.NotFoundHandler(func(c *Context) *HTTPError {
138-
http.Error(c.Response, http.StatusText(http.StatusNotFound), http.StatusNotFound)
139-
return nil
140-
})
141137
e.HTTPErrorHandler(func(he *HTTPError, c *Context) {
142138
if he.Code == 0 {
143139
he.Code = http.StatusInternalServerError
144140
}
145141
if he.Message == "" {
146-
if he.Error != nil {
142+
he.Message = http.StatusText(he.Code)
143+
if e.debug {
147144
he.Message = he.Error.Error()
148-
} else {
149-
he.Message = http.StatusText(he.Code)
150145
}
151146
}
152147
http.Error(c.Response, he.Message, he.Code)
@@ -185,12 +180,6 @@ func (e *Echo) MaxParam(n uint8) {
185180
e.maxParam = n
186181
}
187182

188-
// NotFoundHandler registers a custom NotFound handler used by router in case it
189-
// doesn't find any registered handler for HTTP method and path.
190-
func (e *Echo) NotFoundHandler(h Handler) {
191-
e.notFoundHandler = wrapHandler(h)
192-
}
193-
194183
// HTTPErrorHandler registers an HTTP error handler.
195184
func (e *Echo) HTTPErrorHandler(h HTTPErrorHandler) {
196185
e.httpErrorHandler = h
@@ -207,6 +196,11 @@ func (e *Echo) Renderer(r Renderer) {
207196
e.renderer = r
208197
}
209198

199+
// Debug runs the application in debug mode.
200+
func (e *Echo) Debug(on bool) {
201+
e.debug = on
202+
}
203+
210204
// Use adds handler to the middleware chain.
211205
func (e *Echo) Use(m ...Middleware) {
212206
for _, h := range m {
@@ -325,21 +319,20 @@ func (e *Echo) ServeHTTP(w http.ResponseWriter, r *http.Request) {
325319
if echo != nil {
326320
e = echo
327321
}
328-
if h == nil {
329-
h = e.notFoundHandler
330-
}
331322
c.reset(w, r, e)
323+
if h == nil {
324+
c.Error(&HTTPError{Code: http.StatusNotFound})
325+
} else {
326+
// Chain middleware with handler in the end
327+
for i := len(e.middleware) - 1; i >= 0; i-- {
328+
h = e.middleware[i](h)
329+
}
332330

333-
// Chain middleware with handler in the end
334-
for i := len(e.middleware) - 1; i >= 0; i-- {
335-
h = e.middleware[i](h)
336-
}
337-
338-
// Execute chain
339-
if he := h(c); he != nil {
340-
e.httpErrorHandler(he, c)
331+
// Execute chain
332+
if he := h(c); he != nil {
333+
e.httpErrorHandler(he, c)
334+
}
341335
}
342-
343336
e.pool.Put(c)
344337
}
345338

@@ -394,7 +387,7 @@ func wrapMiddleware(m Middleware) MiddlewareFunc {
394387
case func(http.ResponseWriter, *http.Request):
395388
return wrapHTTPHandlerFuncMW(m)
396389
default:
397-
panic("echo: unknown middleware")
390+
panic("echo unknown middleware")
398391
}
399392
}
400393

@@ -440,7 +433,7 @@ func wrapHandler(h Handler) HandlerFunc {
440433
return nil
441434
}
442435
default:
443-
panic("echo: unknown handler")
436+
panic("echo unknown handler")
444437
}
445438
}
446439

echo_test.go

Lines changed: 0 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -285,16 +285,6 @@ func TestEchoNotFound(t *testing.T) {
285285
if w.Code != http.StatusNotFound {
286286
t.Errorf("status code should be 404, found %d", w.Code)
287287
}
288-
289-
// Customized NotFound handler
290-
e.NotFoundHandler(func(c *Context) *HTTPError {
291-
return c.String(http.StatusNotFound, "not found")
292-
})
293-
w = httptest.NewRecorder()
294-
e.ServeHTTP(w, r)
295-
if w.Body.String() != "not found" {
296-
t.Errorf("body should be `not found`")
297-
}
298288
}
299289

300290
func verifyUser(u2 *user, t *testing.T) {

examples/crud/server.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,7 @@ func main() {
6161
e := echo.New()
6262

6363
// Middleware
64+
e.Use(mw.Recover())
6465
e.Use(mw.Logger())
6566

6667
// Routes

examples/hello/server.go

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,14 @@ func main() {
1616
// Echo instance
1717
e := echo.New()
1818

19+
//------------
1920
// Middleware
21+
//------------
22+
23+
// Recover
24+
e.Use(mw.Recover())
25+
26+
// Logger
2027
e.Use(mw.Logger())
2128

2229
// Routes

examples/middleware/server.go

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,10 +16,16 @@ func main() {
1616
// Echo instance
1717
e := echo.New()
1818

19+
// Debug mode
20+
e.Debug(true)
21+
1922
//------------
2023
// Middleware
2124
//------------
2225

26+
// Recover
27+
e.Use(mw.Recover())
28+
2329
// Logger
2430
e.Use(mw.Logger())
2531

examples/web/server.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,7 @@ func main() {
6565
e := echo.New()
6666

6767
// Middleware
68+
e.Use(mw.Recover())
6869
e.Use(mw.Logger())
6970

7071
//------------------------

middleware/auth.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ const (
1414
Basic = "Basic"
1515
)
1616

17-
// BasicAuth provides HTTP basic authentication.
17+
// BasicAuth returns an HTTP basic authentication middleware.
1818
func BasicAuth(fn AuthFunc) echo.HandlerFunc {
1919
return func(c *echo.Context) (he *echo.HTTPError) {
2020
auth := c.Request.Header.Get(echo.Authorization)

middleware/auth_test.go

Lines changed: 15 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -2,15 +2,15 @@ package middleware
22

33
import (
44
"encoding/base64"
5-
"github.com/labstack/echo"
65
"net/http"
7-
"net/http/httptest"
86
"testing"
7+
8+
"github.com/labstack/echo"
99
)
1010

1111
func TestBasicAuth(t *testing.T) {
1212
req, _ := http.NewRequest(echo.POST, "/", nil)
13-
res := &echo.Response{Writer: httptest.NewRecorder()}
13+
res := &echo.Response{}
1414
c := echo.NewContext(req, res, echo.New())
1515
fn := func(u, p string) bool {
1616
if u == "joe" && p == "secret" {
@@ -34,7 +34,7 @@ func TestBasicAuth(t *testing.T) {
3434
auth = "basic " + base64.StdEncoding.EncodeToString([]byte("joe:secret"))
3535
req.Header.Set(echo.Authorization, auth)
3636
if ba(c) != nil {
37-
t.Error("expected `pass` with case insensitive header")
37+
t.Error("expected `pass`, with case insensitive header.")
3838
}
3939

4040
//---------------------
@@ -46,29 +46,30 @@ func TestBasicAuth(t *testing.T) {
4646
req.Header.Set(echo.Authorization, auth)
4747
ba = BasicAuth(fn)
4848
if ba(c) == nil {
49-
t.Error("expected `fail` with incorrect password")
49+
t.Error("expected `fail`, with incorrect password.")
50+
}
51+
52+
// Empty Authorization header
53+
req.Header.Set(echo.Authorization, "")
54+
ba = BasicAuth(fn)
55+
if ba(c) == nil {
56+
t.Error("expected `fail`, with empty Authorization header.")
5057
}
5158

52-
// Invalid header
59+
// Invalid Authorization header
5360
auth = base64.StdEncoding.EncodeToString([]byte(" :secret"))
5461
req.Header.Set(echo.Authorization, auth)
5562
ba = BasicAuth(fn)
5663
if ba(c) == nil {
57-
t.Error("expected `fail` with invalid auth header")
64+
t.Error("expected `fail`, with invalid Authorization header.")
5865
}
5966

6067
// Invalid scheme
6168
auth = "Base " + base64.StdEncoding.EncodeToString([]byte(" :secret"))
6269
req.Header.Set(echo.Authorization, auth)
6370
ba = BasicAuth(fn)
6471
if ba(c) == nil {
65-
t.Error("expected `fail` with invalid scheme")
72+
t.Error("expected `fail`, with invalid scheme.")
6673
}
6774

68-
// Empty auth header
69-
req.Header.Set(echo.Authorization, "")
70-
ba = BasicAuth(fn)
71-
if ba(c) == nil {
72-
t.Error("expected `fail` with empty auth header")
73-
}
7475
}

middleware/compress.go

Lines changed: 9 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -19,25 +19,21 @@ func (g gzipWriter) Write(b []byte) (int, error) {
1919
return g.Writer.Write(b)
2020
}
2121

22-
// Gzip compresses HTTP response using gzip compression scheme.
22+
// Gzip returns a middleware which compresses HTTP response using gzip compression
23+
// scheme.
2324
func Gzip() echo.MiddlewareFunc {
2425
scheme := "gzip"
2526

2627
return func(h echo.HandlerFunc) echo.HandlerFunc {
2728
return func(c *echo.Context) *echo.HTTPError {
28-
if !strings.Contains(c.Request.Header.Get(echo.AcceptEncoding), scheme) {
29-
return nil
29+
if strings.Contains(c.Request.Header.Get(echo.AcceptEncoding), scheme) {
30+
w := gzip.NewWriter(c.Response.Writer)
31+
defer w.Close()
32+
gw := gzipWriter{Writer: w, Response: c.Response}
33+
c.Response.Header().Set(echo.ContentEncoding, scheme)
34+
c.Response = &echo.Response{Writer: gw}
3035
}
31-
32-
w := gzip.NewWriter(c.Response.Writer)
33-
defer w.Close()
34-
gw := gzipWriter{Writer: w, Response: c.Response}
35-
c.Response.Header().Set(echo.ContentEncoding, scheme)
36-
c.Response = &echo.Response{Writer: gw}
37-
if he := h(c); he != nil {
38-
c.Error(he)
39-
}
40-
return nil
36+
return h(c)
4137
}
4238
}
4339
}

middleware/compress_test.go

Lines changed: 21 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,42 +1,52 @@
11
package middleware
22

33
import (
4+
"compress/gzip"
5+
"io/ioutil"
46
"net/http"
57
"net/http/httptest"
68
"testing"
79

8-
"compress/gzip"
910
"github.com/labstack/echo"
10-
"io/ioutil"
1111
)
1212

1313
func TestGzip(t *testing.T) {
14+
// Empty Accept-Encoding header
1415
req, _ := http.NewRequest(echo.GET, "/", nil)
15-
req.Header.Set(echo.AcceptEncoding, "gzip")
1616
w := httptest.NewRecorder()
1717
res := &echo.Response{Writer: w}
1818
c := echo.NewContext(req, res, echo.New())
19-
Gzip()(func(c *echo.Context) *echo.HTTPError {
19+
h := func(c *echo.Context) *echo.HTTPError {
2020
return c.String(http.StatusOK, "test")
21-
})(c)
21+
}
22+
Gzip()(h)(c)
23+
s := w.Body.String()
24+
if s != "test" {
25+
t.Errorf("expected `test`, with empty Accept-Encoding header, got %s.", s)
26+
}
2227

23-
if w.Header().Get(echo.ContentEncoding) != "gzip" {
24-
t.Errorf("expected Content-Encoding header `gzip`, got %d.", w.Header().Get(echo.ContentEncoding))
28+
// Content-Encoding header
29+
req.Header.Set(echo.AcceptEncoding, "gzip")
30+
w = httptest.NewRecorder()
31+
c.Response = &echo.Response{Writer: w}
32+
Gzip()(h)(c)
33+
ce := w.Header().Get(echo.ContentEncoding)
34+
if ce != "gzip" {
35+
t.Errorf("expected Content-Encoding header `gzip`, got %d.", ce)
2536
}
2637

38+
// Body
2739
r, err := gzip.NewReader(w.Body)
2840
defer r.Close()
2941
if err != nil {
3042
t.Error(err)
3143
}
32-
3344
b, err := ioutil.ReadAll(r)
3445
if err != nil {
3546
t.Error(err)
3647
}
37-
s := string(b)
38-
48+
s = string(b)
3949
if s != "test" {
40-
t.Errorf("expected `test`, got %s.", s)
50+
t.Errorf("expected body `test`, got %s.", s)
4151
}
4252
}

0 commit comments

Comments
 (0)