From 14a7cc3436df0dd9a24dda30c183457f79b22620 Mon Sep 17 00:00:00 2001 From: nbjahan Date: Wed, 4 Dec 2013 14:29:04 +0330 Subject: [PATCH] Treat NotFound as a route --- router.go | 21 +++++++------- router_test.go | 78 ++++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 88 insertions(+), 11 deletions(-) diff --git a/router.go b/router.go index 19bd79f..70b1995 100644 --- a/router.go +++ b/router.go @@ -29,21 +29,21 @@ type Router interface { // Any adds a route for any HTTP method request to the specified matching pattern. Any(string, ...Handler) Route - // NotFound sets the handler that is called when a no route matches a request. Throws a basic 404 by default. - NotFound(Handler) + // NotFound sets the handlers that are called when a no route matches a request. Throws a basic 404 by default. + NotFound(...Handler) // Handle is the entry point for routing. This is used as a martini.Handler Handle(http.ResponseWriter, *http.Request, Context) } type router struct { - routes []*route - notFound Handler + routes []*route + notFounds []Handler } // NewRouter creates a new Router instance. func NewRouter() Router { - return &router{notFound: http.NotFound} + return &router{notFounds: []Handler{http.NotFound}} } func (r *router) Get(pattern string, h ...Handler) Route { @@ -91,14 +91,13 @@ func (r *router) Handle(res http.ResponseWriter, req *http.Request, context Cont } // no routes exist, 404 - _, err := context.Invoke(r.notFound) - if err != nil { - panic(err) - } + c := &routeContext{context, 0, r.notFounds} + context.MapTo(c, (*Context)(nil)) + c.run() } -func (r *router) NotFound(handler Handler) { - r.notFound = handler +func (r *router) NotFound(handler ...Handler) { + r.notFounds = handler } func (r *router) addRoute(method string, pattern string, handlers []Handler) *route { diff --git a/router_test.go b/router_test.go index 2ad9c89..866c754 100644 --- a/router_test.go +++ b/router_test.go @@ -225,6 +225,84 @@ func Test_NotFound(t *testing.T) { expect(t, recorder.Body.String(), "Nope\n") } +func Test_NotFoundAsHandler(t *testing.T) { + router := NewRouter() + recorder := httptest.NewRecorder() + + req, _ := http.NewRequest("GET", "http://localhost:3000/foo", nil) + context := New().createContext(recorder, req) + + router.NotFound(func() string { + return "not found" + }) + + router.Handle(recorder, req, context) + expect(t, recorder.Code, http.StatusOK) + expect(t, recorder.Body.String(), "not found") + + recorder = httptest.NewRecorder() + + context = New().createContext(recorder, req) + + router.NotFound(func() (int, string) { + return 404, "not found" + }) + + router.Handle(recorder, req, context) + expect(t, recorder.Code, http.StatusNotFound) + expect(t, recorder.Body.String(), "not found") + + recorder = httptest.NewRecorder() + + context = New().createContext(recorder, req) + + router.NotFound(func() (int, string) { + return 200, "" + }) + + router.Handle(recorder, req, context) + expect(t, recorder.Code, http.StatusOK) + expect(t, recorder.Body.String(), "") +} + +func Test_NotFoundStacking(t *testing.T) { + router := NewRouter() + recorder := httptest.NewRecorder() + + req, err := http.NewRequest("GET", "http://localhost:3000/foo", nil) + if err != nil { + t.Error(err) + } + context := New().createContext(recorder, req) + + result := "" + + f1 := func() { + result += "foo" + } + + f2 := func(c Context) { + result += "bar" + c.Next() + result += "bing" + } + + f3 := func() string { + result += "bat" + return "Not Found" + } + + f4 := func() { + result += "baz" + } + + router.NotFound(f1, f2, f3, f4) + + router.Handle(recorder, req, context) + expect(t, result, "foobarbatbing") + expect(t, recorder.Body.String(), "Not Found") +} + func Test_Any(t *testing.T) { router := NewRouter() router.Any("/foo", func(res http.ResponseWriter) {