Skip to content

Commit

Permalink
refactor: use http.Request.URL.Path to match
Browse files Browse the repository at this point in the history
  • Loading branch information
aofei committed Nov 22, 2021
1 parent b991891 commit dba0341
Show file tree
Hide file tree
Showing 3 changed files with 23 additions and 47 deletions.
1 change: 0 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@ R2 is built for people who:
* Based on [radix tree](https://en.wikipedia.org/wiki/Radix_tree)
* Sub-router support
* Path parameter support
* Path auto-correction support
* No [`http.Handler`](https://pkg.go.dev/net/http#Handler) variant
* Middleware support
* Zero third-party dependencies
Expand Down
55 changes: 17 additions & 38 deletions router.go
Original file line number Diff line number Diff line change
Expand Up @@ -451,9 +451,7 @@ func (r *Router) insertRoute(
}

// Handler returns a matched `http.Handler` for the `req` along with a possible
// revision of the `req`. It takes the `req.Method` and absolute path from the
// `req.RequestURI` to match, so the `req.RequestURI` is assumed to be in the
// origin-form (see RFC 7230, section 5.3.1).
// revision of the `req`.
//
// The returned `http.Handler` is always non-nil.
//
Expand All @@ -465,49 +463,30 @@ func (r *Router) Handler(req *http.Request) (http.Handler, *http.Request) {
return r.Parent.Handler(req)
}

if r.routeTree == nil ||
req.RequestURI == "" ||
req.RequestURI[0] != '/' {
if r.routeTree == nil {
return r.notFoundHandler(), req
}

path := req.RequestURI
for i := 1; i < len(path); i++ {
if path[i] == '?' {
path = path[:i]
break
}
}

var (
s = path // Search
si int // Search index
sl int // Search length
pl int // Prefix length
ll int // LCP length
ml int // Minimum length of the `sl` and `pl`
cn = r.routeTree // Current node
sn *routeNode // Saved node
fnt routeNodeType // From node type
nnt routeNodeType // Next node type
ppi int // Path parameter index
ppvs []string // Path parameter values
i int // Index
h http.Handler // Handler
s = req.URL.Path // Search
si int // Search index
sl int // Search length
pl int // Prefix length
ll int // LCP length
ml int // Minimum length of the `sl` and `pl`
cn = r.routeTree // Current node
sn *routeNode // Saved node
fnt routeNodeType // From node type
nnt routeNodeType // Next node type
ppi int // Path parameter index
ppvs []string // Path parameter values
i int // Index
h http.Handler // Handler
)

// Node search order: static > parameter > wildcard parameter.
OuterLoop:
for {
// Skip continuous '/'.
if s != "" && s[0] == '/' {
i, sl = 1, len(s)
for ; i < sl && s[i] == '/'; i++ {
}

s = s[i-1:]
}

if cn.typ == staticRouteNode {
sl, pl = len(s), len(cn.prefix)
if sl < pl {
Expand Down Expand Up @@ -671,7 +650,7 @@ OuterLoop:
si -= len(ppvs[ppi])
}

s = path[si:]
s = req.URL.Path[si:]
}

if cn.typ < wildcardParamRouteNode {
Expand Down
14 changes: 6 additions & 8 deletions router_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -246,7 +246,7 @@ func TestRouterHandler(t *testing.T) {
r = &Router{}
r.Handle("", "/", h, mf)
req = httptest.NewRequest(http.MethodGet, "/", nil)
req.RequestURI = ""
req.URL.Path = ""
rec = httptest.NewRecorder()
mh, req = r.Handler(req)
mh.ServeHTTP(rec, req)
Expand Down Expand Up @@ -328,11 +328,11 @@ func TestRouterHandler_static(t *testing.T) {
mh, req = r.Handler(req)
mh.ServeHTTP(rec, req)
recr = rec.Result()
if want := http.StatusOK; recr.StatusCode != want {
if want := http.StatusNotFound; recr.StatusCode != want {
t.Errorf("got %d, want %d", recr.StatusCode, want)
} else if b, err := ioutil.ReadAll(recr.Body); err != nil {
t.Fatalf("unexpected error %q", err)
} else if want := "GET /"; string(b) != want {
} else if want := "Not Found\n"; string(b) != want {
t.Errorf("got %q, want %q", b, want)
}

Expand Down Expand Up @@ -457,14 +457,12 @@ func TestRouterHandler_param(t *testing.T) {
mh, req = r.Handler(req)
mh.ServeHTTP(rec, req)
recr = rec.Result()
if want := http.StatusOK; recr.StatusCode != want {
if want := http.StatusNotFound; recr.StatusCode != want {
t.Errorf("got %d, want %d", recr.StatusCode, want)
} else if b, err := ioutil.ReadAll(recr.Body); err != nil {
t.Fatalf("unexpected error %q", err)
} else if want := "GET /:foobar"; string(b) != want {
} else if want := "Not Found\n"; string(b) != want {
t.Errorf("got %q, want %q", b, want)
} else if got, want := PathParam(req, "foobar"), ""; got != want {
t.Errorf("got %q, want %q", got, want)
}

req = httptest.NewRequest(http.MethodHead, "/", nil)
Expand Down Expand Up @@ -607,7 +605,7 @@ func TestRouterHandler_wildcardParam(t *testing.T) {
t.Fatalf("unexpected error %q", err)
} else if want := "GET /*"; string(b) != want {
t.Errorf("got %q, want %q", b, want)
} else if got, want := PathParam(req, "*"), ""; got != want {
} else if got, want := PathParam(req, "*"), "/"; got != want {
t.Errorf("got %q, want %q", got, want)
}

Expand Down

0 comments on commit dba0341

Please sign in to comment.