Skip to content

Commit

Permalink
Added MethodsFor to Routes helper
Browse files Browse the repository at this point in the history
  • Loading branch information
shirro committed Feb 9, 2014
1 parent b34eed6 commit 26855a7
Show file tree
Hide file tree
Showing 2 changed files with 53 additions and 2 deletions.
27 changes: 25 additions & 2 deletions router.go
Original file line number Diff line number Diff line change
Expand Up @@ -80,13 +80,12 @@ func (r *router) Any(pattern string, h ...Handler) Route {
}

func (r *router) Handle(res http.ResponseWriter, req *http.Request, context Context) {
context.MapTo(routes{r}, (*Routes)(nil))
for _, route := range r.routes {
ok, vals := route.Match(req.Method, req.URL.Path)
if ok {
params := Params(vals)
context.Map(params)
r := routes{r}
context.MapTo(r, (*Routes)(nil))
route.Handle(context, res)
return
}
Expand Down Expand Up @@ -216,6 +215,8 @@ func (r *route) Name(name string) {
type Routes interface {
// URLFor returns a rendered URL for the given route. Optional params can be passed to fulfill named parameters in the route.
URLFor(name string, params ...interface{}) string
// MethodsFor returns an array of methods available for the path
MethodsFor(path string) []string
}

type routes struct {
Expand Down Expand Up @@ -247,6 +248,28 @@ func (r routes) URLFor(name string, params ...interface{}) string {
return route.URLWith(args)
}

func hasMethod(methods []string, method string) bool {
for _, v := range methods {
if v == method {
return true
}
}
return false
}

// MethodsFor returns all methods available for path
func (r routes) MethodsFor(path string) []string {
methods := []string{}
for _, route := range r.router.routes {
if route.regex.MatchString(path) {
if !hasMethod(methods, route.method) {
methods = append(methods, route.method)
}
}
}
return methods
}

type routeContext struct {
Context
index int
Expand Down
28 changes: 28 additions & 0 deletions router_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package martini
import (
"net/http"
"net/http/httptest"
"strings"
"testing"
)

Expand Down Expand Up @@ -217,6 +218,33 @@ func Test_RouteMatching(t *testing.T) {
}
}

func Test_MethodsFor(t *testing.T) {
router := NewRouter()
recorder := httptest.NewRecorder()

req, _ := http.NewRequest("POST", "http://localhost:3000/foo", nil)
context := New().createContext(recorder, req)
router.Post("/foo/bar", func() {
})

router.Get("/foo", func() {
})

router.Put("/foo", func() {
})

router.NotFound(func(routes Routes, w http.ResponseWriter, r *http.Request) {
methods := routes.MethodsFor(r.URL.Path)
if len(methods) != 0 {
w.Header().Set("Allow", strings.Join(methods, ","))
w.WriteHeader(http.StatusMethodNotAllowed)
}
})
router.Handle(recorder, req, context)
expect(t, recorder.Code, http.StatusMethodNotAllowed)
expect(t, recorder.Header().Get("Allow"), "GET,PUT")
}

func Test_NotFound(t *testing.T) {
router := NewRouter()
recorder := httptest.NewRecorder()
Expand Down

0 comments on commit 26855a7

Please sign in to comment.