Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 43 additions & 1 deletion middleware.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
package mux

import "net/http"
import (
"net/http"
"strings"
)

// MiddlewareFunc is a function which receives an http.Handler and returns another http.Handler.
// Typically, the returned handler is a closure which does something with the http.ResponseWriter and http.Request passed
Expand Down Expand Up @@ -28,3 +31,42 @@ func (r *Router) Use(mwf ...MiddlewareFunc) {
func (r *Router) useInterface(mw middleware) {
r.middlewares = append(r.middlewares, mw)
}

// CORSMethodMiddleware sets the Access-Control-Allow-Methods response header
// on a request, by matching routes based only on paths. It also handles
// OPTIONS requests, by settings Access-Control-Allow-Methods, and then
// returning without calling the next http handler.
func CORSMethodMiddleware(r *Router) MiddlewareFunc {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
var allMethods []string

err := r.Walk(func(route *Route, _ *Router, _ []*Route) error {
for _, m := range route.matchers {
if _, ok := m.(*routeRegexp); ok {
if m.Match(req, &RouteMatch{}) {
methods, err := route.GetMethods()
if err != nil {
return err
}

allMethods = append(allMethods, methods...)
}
break
}
}
return nil
})

if err == nil {
w.Header().Set("Access-Control-Allow-Methods", strings.Join(append(allMethods, "OPTIONS"), ","))

if req.Method == "OPTIONS" {
return
}
}

next.ServeHTTP(w, req)
})
}
}
41 changes: 41 additions & 0 deletions middleware_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package mux
import (
"bytes"
"net/http"
"net/http/httptest"
"testing"
)

Expand Down Expand Up @@ -334,3 +335,43 @@ func TestMiddlewareMethodMismatchSubrouter(t *testing.T) {
t.Fatal("Middleware was called for a method mismatch")
}
}

func TestCORSMethodMiddleware(t *testing.T) {
router := NewRouter()

cases := []struct {
path string
response string
method string
testURL string
expectedAllowedMethods string
}{
{"/g/{o}", "a", "POST", "/g/asdf", "POST,PUT,GET,OPTIONS"},
{"/g/{o}", "b", "PUT", "/g/bla", "POST,PUT,GET,OPTIONS"},
{"/g/{o}", "c", "GET", "/g/orilla", "POST,PUT,GET,OPTIONS"},
{"/g", "d", "POST", "/g", "POST,OPTIONS"},
}

for _, tt := range cases {
router.HandleFunc(tt.path, stringHandler(tt.response)).Methods(tt.method)
}

router.Use(CORSMethodMiddleware(router))

for _, tt := range cases {
rr := httptest.NewRecorder()
req := newRequest(tt.method, tt.testURL)

router.ServeHTTP(rr, req)

if rr.Body.String() != tt.response {
t.Errorf("Expected body '%s', found '%s'", tt.response, rr.Body.String())
}

allowedMethods := rr.HeaderMap.Get("Access-Control-Allow-Methods")

if allowedMethods != tt.expectedAllowedMethods {
t.Errorf("Expected Access-Control-Allow-Methods '%s', found '%s'", tt.expectedAllowedMethods, allowedMethods)
}
}
}
8 changes: 8 additions & 0 deletions mux_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2315,6 +2315,14 @@ func stringMapEqual(m1, m2 map[string]string) bool {
return true
}

// stringHandler returns a handler func that writes a message 's' to the
// http.ResponseWriter.
func stringHandler(s string) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
w.Write([]byte(s))
}
}

// newRequest is a helper function to create a new request with a method and url.
// The request returned is a 'server' request as opposed to a 'client' one through
// simulated write onto the wire and read off of the wire.
Expand Down