diff --git a/node.go b/node.go index 12b650a..5f99121 100644 --- a/node.go +++ b/node.go @@ -112,9 +112,15 @@ func (n *node) addPart(part string) *node { func (n *node) findRoute(meth, path string) (*node, routeHandler, int) { if path == "" { - return n, n.handler(meth), 0 + if n.handlerMap != nil { + return n, n.handlerMap.Get(meth), 0 + } + return nil, routeHandler{}, 0 } + return n._findRoute(meth, path) +} +func (n *node) _findRoute(meth, path string) (*node, routeHandler, int) { var found *node if firstChar := path[0]; firstChar >= n.index.minChar && firstChar <= n.index.maxChar { @@ -122,16 +128,16 @@ func (n *node) findRoute(meth, path string) (*node, routeHandler, int) { childNode := n.nodes[i-1] if childNode.part == path { - if handler := childNode.handler(meth); handler.fn != nil { - return childNode, handler, 0 - } if childNode.handlerMap != nil { + if handler := childNode.handlerMap.Get(meth); handler.fn != nil { + return childNode, handler, 0 + } found = childNode } } else { partLen := len(childNode.part) if strings.HasPrefix(path, childNode.part) { - node, handler, wildcardLen := childNode.findRoute(meth, path[partLen:]) + node, handler, wildcardLen := childNode._findRoute(meth, path[partLen:]) if handler.fn != nil { return node, handler, wildcardLen } @@ -145,22 +151,22 @@ func (n *node) findRoute(meth, path string) (*node, routeHandler, int) { if n.colon != nil { if i := strings.IndexByte(path, '/'); i > 0 { - node, handler, wildcardLen := n.colon.findRoute(meth, path[i:]) + node, handler, wildcardLen := n.colon._findRoute(meth, path[i:]) if handler.fn != nil { return node, handler, wildcardLen } - } else { - if handler := n.colon.handler(meth); handler.fn != nil { + } else if n.colon.handlerMap != nil { + if handler := n.colon.handlerMap.Get(meth); handler.fn != nil { return n.colon, handler, 0 } - if found == nil && n.colon.handlerMap != nil { + if found == nil { found = n.colon } } } if n.isWildcard && n.handlerMap != nil { - if handler := n.handler(meth); handler.fn != nil { + if handler := n.handlerMap.Get(meth); handler.fn != nil { if handler.slash { return n, handler, len(path) - 1 } @@ -218,13 +224,6 @@ func (n *node) setHandler(verb string, handler routeHandler) { n.handlerMap.Set(verb, handler) } -func (n *node) handler(verb string) routeHandler { - if n.handlerMap == nil { - return routeHandler{} - } - return n.handlerMap.Get(verb) -} - //------------------------------------------------------------------------------ type routeParser struct { diff --git a/router.go b/router.go index ae30dd9..64ed87a 100644 --- a/router.go +++ b/router.go @@ -16,7 +16,10 @@ type Router struct { func New(opts ...Option) *Router { r := &Router{ - tree: node{route: "/", part: "/"}, + tree: node{ + route: "/", + part: "/", + }, } r.Group.router = r @@ -63,9 +66,10 @@ func (r *Router) lookup(w http.ResponseWriter, req *http.Request) (HandlerFunc, node, handler, wildcardLen := r.tree.findRoute(req.Method, path[1:]) if node == nil { // Path was not found. Try cleaning it up and search again. - cleanPath := CleanPath(unescapedPath) - if found, _, _ := r.tree.findRoute(req.Method, cleanPath[1:]); found != nil { - return redirectHandler(cleanPath), Params{} + if cleanPath := CleanPath(unescapedPath); cleanPath != path { + if found, _, _ := r.tree.findRoute(req.Method, cleanPath[1:]); found != nil { + return redirectHandler(cleanPath), Params{} + } } return r.notFoundHandler, Params{} diff --git a/router_test.go b/router_test.go index bd3dcae..d3fe155 100644 --- a/router_test.go +++ b/router_test.go @@ -700,10 +700,19 @@ func TestRoutesWithCommonPrefix(t *testing.T) { router.GET("/campaigns", simpleHandler) router.GET("/causes", simpleHandler) - w := httptest.NewRecorder() - req, _ := http.NewRequest("GET", "/ca", nil) - router.ServeHTTP(w, req) - require.Equal(t, http.StatusNotFound, w.Code) + { + w := httptest.NewRecorder() + req, _ := http.NewRequest("GET", "/ca", nil) + router.ServeHTTP(w, req) + require.Equal(t, http.StatusNotFound, w.Code) + } + + { + w := httptest.NewRecorder() + req, _ := http.NewRequest("GET", "/", nil) + router.ServeHTTP(w, req) + require.Equal(t, http.StatusNotFound, w.Code) + } } func TestNotAllowedMiddleware(t *testing.T) {