Skip to content

Commit

Permalink
support cors in rest server
Browse files Browse the repository at this point in the history
  • Loading branch information
kevwan committed Oct 21, 2020
1 parent 1c1e4bc commit fe0d068
Show file tree
Hide file tree
Showing 8 changed files with 122 additions and 8 deletions.
2 changes: 1 addition & 1 deletion example/http/demo/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ func main() {
Port: *port,
Timeout: *timeout,
MaxConns: 500,
})
}, rest.WithNotAllowedHandler(rest.CorsHandler()))
defer engine.Stop()

engine.Use(first)
Expand Down
29 changes: 29 additions & 0 deletions rest/handlers.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
package rest

import (
"net/http"
"strings"
)

const (
allowOrigin = "Access-Control-Allow-Origin"
allOrigin = "*"
allowMethods = "Access-Control-Allow-Methods"
allowHeaders = "Access-Control-Allow-Headers"
headers = "Content-Type, Content-Length, Origin"
methods = "GET, HEAD, POST, PATCH, PUT, DELETE"
separator = ", "
)

func CorsHandler(origins ...string) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if len(origins) > 0 {
w.Header().Set(allowOrigin, strings.Join(origins, separator))
} else {
w.Header().Set(allowOrigin, allOrigin)
}
w.Header().Set(allowMethods, methods)
w.Header().Set(allowHeaders, headers)
w.WriteHeader(http.StatusNoContent)
})
}
27 changes: 27 additions & 0 deletions rest/handlers_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
package rest

import (
"net/http"
"net/http/httptest"
"strings"
"testing"

"github.com/stretchr/testify/assert"
)

func TestCorsHandler(t *testing.T) {
w := httptest.NewRecorder()
handler := CorsHandler()
handler.ServeHTTP(w, nil)
assert.Equal(t, http.StatusNoContent, w.Result().StatusCode)
assert.Equal(t, allOrigin, w.Header().Get(allowOrigin))
}

func TestCorsHandlerWithOrigins(t *testing.T) {
origins := []string{"local", "remote"}
w := httptest.NewRecorder()
handler := CorsHandler(origins...)
handler.ServeHTTP(w, nil)
assert.Equal(t, http.StatusNoContent, w.Result().StatusCode)
assert.Equal(t, strings.Join(origins, separator), w.Header().Get(allowOrigin))
}
1 change: 1 addition & 0 deletions rest/httpx/router.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,4 +6,5 @@ type Router interface {
http.Handler
Handle(method string, path string, handler http.Handler) error
SetNotFoundHandler(handler http.Handler)
SetNotAllowedHandler(handler http.Handler)
}
21 changes: 16 additions & 5 deletions rest/router/patrouter.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,9 @@ var (
)

type patRouter struct {
trees map[string]*search.Tree
notFound http.Handler
trees map[string]*search.Tree
notFound http.Handler
notAllowed http.Handler
}

func NewRouter() httpx.Router {
Expand Down Expand Up @@ -63,18 +64,28 @@ func (pr *patRouter) ServeHTTP(w http.ResponseWriter, r *http.Request) {
}
}

if allow, ok := pr.methodNotAllowed(r.Method, reqPath); ok {
allow, ok := pr.methodNotAllowed(r.Method, reqPath)
if !ok {
pr.handleNotFound(w, r)
return
}

if pr.notAllowed != nil {
pr.notAllowed.ServeHTTP(w, r)
} else {
w.Header().Set(allowHeader, allow)
w.WriteHeader(http.StatusMethodNotAllowed)
} else {
pr.handleNotFound(w, r)
}
}

func (pr *patRouter) SetNotFoundHandler(handler http.Handler) {
pr.notFound = handler
}

func (pr *patRouter) SetNotAllowedHandler(handler http.Handler) {
pr.notAllowed = handler
}

func (pr *patRouter) handleNotFound(w http.ResponseWriter, r *http.Request) {
if pr.notFound != nil {
pr.notFound.ServeHTTP(w, r)
Expand Down
19 changes: 18 additions & 1 deletion rest/router/patrouter_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -60,13 +60,30 @@ func TestPatRouterNotFound(t *testing.T) {
router.SetNotFoundHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
notFound = true
}))
router.Handle(http.MethodGet, "/a/b", nil)
err := router.Handle(http.MethodGet, "/a/b",
http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}))
assert.Nil(t, err)
r, _ := http.NewRequest(http.MethodGet, "/b/c", nil)
w := new(mockedResponseWriter)
router.ServeHTTP(w, r)
assert.True(t, notFound)
}

func TestPatRouterNotAllowed(t *testing.T) {
var notAllowed bool
router := NewRouter()
router.SetNotAllowedHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
notAllowed = true
}))
err := router.Handle(http.MethodGet, "/a/b",
http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}))
assert.Nil(t, err)
r, _ := http.NewRequest(http.MethodPost, "/a/b", nil)
w := new(mockedResponseWriter)
router.ServeHTTP(w, r)
assert.True(t, notAllowed)
}

func TestPatRouter(t *testing.T) {
tests := []struct {
method string
Expand Down
18 changes: 18 additions & 0 deletions rest/server.go
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
package rest

import (
"errors"
"log"
"net/http"

"github.com/tal-tech/go-zero/core/logx"
"github.com/tal-tech/go-zero/rest/handler"
"github.com/tal-tech/go-zero/rest/httpx"
"github.com/tal-tech/go-zero/rest/router"
)

type (
Expand All @@ -32,6 +34,10 @@ func MustNewServer(c RestConf, opts ...RunOption) *Server {
}

func NewServer(c RestConf, opts ...RunOption) (*Server, error) {
if len(opts) > 1 {
return nil, errors.New("only one RunOption is allowed")
}

if err := c.SetUp(); err != nil {
return nil, err
}
Expand Down Expand Up @@ -125,6 +131,18 @@ func WithMiddleware(middleware Middleware, rs ...Route) []Route {
return routes
}

func WithNotFoundHandler(handler http.Handler) RunOption {
rt := router.NewRouter()
rt.SetNotFoundHandler(handler)
return WithRouter(rt)
}

func WithNotAllowedHandler(handler http.Handler) RunOption {
rt := router.NewRouter()
rt.SetNotAllowedHandler(handler)
return WithRouter(rt)
}

func WithPriority() RouteOption {
return func(r *featuredRoutes) {
r.priority = true
Expand Down
13 changes: 12 additions & 1 deletion rest/server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,11 @@ import (
"github.com/tal-tech/go-zero/rest/router"
)

func TestNewServer(t *testing.T) {
_, err := NewServer(RestConf{}, WithNotFoundHandler(nil), WithNotAllowedHandler(nil))
assert.NotNil(t, err)
}

func TestWithMiddleware(t *testing.T) {
m := make(map[string]string)
router := router.NewRouter()
Expand Down Expand Up @@ -69,7 +74,7 @@ func TestWithMiddleware(t *testing.T) {
}, m)
}

func TestMultiMiddleware(t *testing.T) {
func TestMultiMiddlewares(t *testing.T) {
m := make(map[string]string)
router := router.NewRouter()
handler := func(w http.ResponseWriter, r *http.Request) {
Expand Down Expand Up @@ -140,3 +145,9 @@ func TestMultiMiddleware(t *testing.T) {
"whatever": "200000200000",
}, m)
}

func TestWithPriority(t *testing.T) {
var fr featuredRoutes
WithPriority()(&fr)
assert.True(t, fr.priority)
}

0 comments on commit fe0d068

Please sign in to comment.