Skip to content

Commit

Permalink
fix: properly handle empty root node
Browse files Browse the repository at this point in the history
  • Loading branch information
vmihailenco committed Nov 16, 2021
1 parent fc827bf commit b37ad45
Show file tree
Hide file tree
Showing 3 changed files with 37 additions and 25 deletions.
33 changes: 16 additions & 17 deletions node.go
Original file line number Diff line number Diff line change
Expand Up @@ -112,26 +112,32 @@ func (n *node) addPart(part string) *node {

func (n *node) findRoute(meth, path string) (*node, routeHandler, int) {
if path == "" {
return n, n.handler(meth), 0
if n.handlerMap != nil {
return n, n.handlerMap.Get(meth), 0
}
return nil, routeHandler{}, 0
}
return n._findRoute(meth, path)
}

func (n *node) _findRoute(meth, path string) (*node, routeHandler, int) {
var found *node

if firstChar := path[0]; firstChar >= n.index.minChar && firstChar <= n.index.maxChar {
if i := n.index.table[firstChar-n.index.minChar]; i != 0 {
childNode := n.nodes[i-1]

if childNode.part == path {
if handler := childNode.handler(meth); handler.fn != nil {
return childNode, handler, 0
}
if childNode.handlerMap != nil {
if handler := childNode.handlerMap.Get(meth); handler.fn != nil {
return childNode, handler, 0
}
found = childNode
}
} else {
partLen := len(childNode.part)
if strings.HasPrefix(path, childNode.part) {
node, handler, wildcardLen := childNode.findRoute(meth, path[partLen:])
node, handler, wildcardLen := childNode._findRoute(meth, path[partLen:])
if handler.fn != nil {
return node, handler, wildcardLen
}
Expand All @@ -145,22 +151,22 @@ func (n *node) findRoute(meth, path string) (*node, routeHandler, int) {

if n.colon != nil {
if i := strings.IndexByte(path, '/'); i > 0 {
node, handler, wildcardLen := n.colon.findRoute(meth, path[i:])
node, handler, wildcardLen := n.colon._findRoute(meth, path[i:])
if handler.fn != nil {
return node, handler, wildcardLen
}
} else {
if handler := n.colon.handler(meth); handler.fn != nil {
} else if n.colon.handlerMap != nil {
if handler := n.colon.handlerMap.Get(meth); handler.fn != nil {
return n.colon, handler, 0
}
if found == nil && n.colon.handlerMap != nil {
if found == nil {
found = n.colon
}
}
}

if n.isWildcard && n.handlerMap != nil {
if handler := n.handler(meth); handler.fn != nil {
if handler := n.handlerMap.Get(meth); handler.fn != nil {
if handler.slash {
return n, handler, len(path) - 1
}
Expand Down Expand Up @@ -218,13 +224,6 @@ func (n *node) setHandler(verb string, handler routeHandler) {
n.handlerMap.Set(verb, handler)
}

func (n *node) handler(verb string) routeHandler {
if n.handlerMap == nil {
return routeHandler{}
}
return n.handlerMap.Get(verb)
}

//------------------------------------------------------------------------------

type routeParser struct {
Expand Down
12 changes: 8 additions & 4 deletions router.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,10 @@ type Router struct {

func New(opts ...Option) *Router {
r := &Router{
tree: node{route: "/", part: "/"},
tree: node{
route: "/",
part: "/",
},
}

r.Group.router = r
Expand Down Expand Up @@ -63,9 +66,10 @@ func (r *Router) lookup(w http.ResponseWriter, req *http.Request) (HandlerFunc,
node, handler, wildcardLen := r.tree.findRoute(req.Method, path[1:])
if node == nil {
// Path was not found. Try cleaning it up and search again.
cleanPath := CleanPath(unescapedPath)
if found, _, _ := r.tree.findRoute(req.Method, cleanPath[1:]); found != nil {
return redirectHandler(cleanPath), Params{}
if cleanPath := CleanPath(unescapedPath); cleanPath != path {
if found, _, _ := r.tree.findRoute(req.Method, cleanPath[1:]); found != nil {
return redirectHandler(cleanPath), Params{}
}
}

return r.notFoundHandler, Params{}
Expand Down
17 changes: 13 additions & 4 deletions router_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -700,10 +700,19 @@ func TestRoutesWithCommonPrefix(t *testing.T) {
router.GET("/campaigns", simpleHandler)
router.GET("/causes", simpleHandler)

w := httptest.NewRecorder()
req, _ := http.NewRequest("GET", "/ca", nil)
router.ServeHTTP(w, req)
require.Equal(t, http.StatusNotFound, w.Code)
{
w := httptest.NewRecorder()
req, _ := http.NewRequest("GET", "/ca", nil)
router.ServeHTTP(w, req)
require.Equal(t, http.StatusNotFound, w.Code)
}

{
w := httptest.NewRecorder()
req, _ := http.NewRequest("GET", "/", nil)
router.ServeHTTP(w, req)
require.Equal(t, http.StatusNotFound, w.Code)
}
}

func TestNotAllowedMiddleware(t *testing.T) {
Expand Down

0 comments on commit b37ad45

Please sign in to comment.