diff --git a/router.go b/router.go index d5ad535..bee49b9 100644 --- a/router.go +++ b/router.go @@ -80,13 +80,12 @@ func (r *router) Any(pattern string, h ...Handler) Route { } func (r *router) Handle(res http.ResponseWriter, req *http.Request, context Context) { + context.MapTo(routes{r}, (*Routes)(nil)) for _, route := range r.routes { ok, vals := route.Match(req.Method, req.URL.Path) if ok { params := Params(vals) context.Map(params) - r := routes{r} - context.MapTo(r, (*Routes)(nil)) route.Handle(context, res) return } @@ -216,6 +215,8 @@ func (r *route) Name(name string) { type Routes interface { // URLFor returns a rendered URL for the given route. Optional params can be passed to fulfill named parameters in the route. URLFor(name string, params ...interface{}) string + // MethodsFor returns an array of methods available for the path + MethodsFor(path string) []string } type routes struct { @@ -247,6 +248,28 @@ func (r routes) URLFor(name string, params ...interface{}) string { return route.URLWith(args) } +func hasMethod(methods []string, method string) bool { + for _, v := range methods { + if v == method { + return true + } + } + return false +} + +// MethodsFor returns all methods available for path +func (r routes) MethodsFor(path string) []string { + methods := []string{} + for _, route := range r.router.routes { + if route.regex.MatchString(path) { + if !hasMethod(methods, route.method) { + methods = append(methods, route.method) + } + } + } + return methods +} + type routeContext struct { Context index int diff --git a/router_test.go b/router_test.go index 85d45d5..9d3ecea 100644 --- a/router_test.go +++ b/router_test.go @@ -3,6 +3,7 @@ package martini import ( "net/http" "net/http/httptest" + "strings" "testing" ) @@ -217,6 +218,33 @@ func Test_RouteMatching(t *testing.T) { } } +func Test_MethodsFor(t *testing.T) { + router := NewRouter() + recorder := httptest.NewRecorder() + + req, _ := http.NewRequest("POST", "http://localhost:3000/foo", nil) + context := New().createContext(recorder, req) + router.Post("/foo/bar", func() { + }) + + router.Get("/foo", func() { + }) + + router.Put("/foo", func() { + }) + + router.NotFound(func(routes Routes, w http.ResponseWriter, r *http.Request) { + methods := routes.MethodsFor(r.URL.Path) + if len(methods) != 0 { + w.Header().Set("Allow", strings.Join(methods, ",")) + w.WriteHeader(http.StatusMethodNotAllowed) + } + }) + router.Handle(recorder, req, context) + expect(t, recorder.Code, http.StatusMethodNotAllowed) + expect(t, recorder.Header().Get("Allow"), "GET,PUT") +} + func Test_NotFound(t *testing.T) { router := NewRouter() recorder := httptest.NewRecorder()