From 758eb64354aa27cda6c1b026822c788cc755f06f Mon Sep 17 00:00:00 2001 From: Joe Wilner Date: Fri, 7 Dec 2018 10:48:26 -0500 Subject: [PATCH] Improve subroute configuration propagation #422 * Pull out common shared `routeConf` so that config is pushed on to child routers and routes. * Removes obsolete usages of `parentRoute` * Add tests defining compositional behavior * Exercise `copyRouteConf` for posterity --- mux.go | 114 ++++++++------ mux_test.go | 440 +++++++++++++++++++++++++++++++++++++++++++++++++--- regexp.go | 2 +- route.go | 125 ++++----------- 4 files changed, 511 insertions(+), 170 deletions(-) diff --git a/mux.go b/mux.go index 4bbafa51..50ac1184 100644 --- a/mux.go +++ b/mux.go @@ -50,24 +50,77 @@ type Router struct { // Configurable Handler to be used when the request method does not match the route. MethodNotAllowedHandler http.Handler - // Parent route, if this is a subrouter. - parent parentRoute // Routes to be matched, in order. routes []*Route + // Routes by name for URL building. namedRoutes map[string]*Route - // See Router.StrictSlash(). This defines the flag for new routes. - strictSlash bool - // See Router.SkipClean(). This defines the flag for new routes. - skipClean bool + // If true, do not clear the request context after handling the request. // This has no effect when go1.7+ is used, since the context is stored // on the request itself. KeepContext bool - // see Router.UseEncodedPath(). This defines a flag for all routes. - useEncodedPath bool + // Slice of middlewares to be called after a match is found middlewares []middleware + + // configuration shared with `Route` + routeConf +} + +// common route configuration shared between `Router` and `Route` +type routeConf struct { + // If true, "/path/foo%2Fbar/to" will match the path "/path/{var}/to" + useEncodedPath bool + + // If true, when the path pattern is "/path/", accessing "/path" will + // redirect to the former and vice versa. + strictSlash bool + + // If true, when the path pattern is "/path//to", accessing "/path//to" + // will not redirect + skipClean bool + + // Manager for the variables from host and path. + regexp routeRegexpGroup + + // List of matchers. + matchers []matcher + + // The scheme used when building URLs. + buildScheme string + + buildVarsFunc BuildVarsFunc +} + +// returns an effective deep copy of `routeConf` +func copyRouteConf(r routeConf) routeConf { + c := r + + if r.regexp.path != nil { + c.regexp.path = copyRouteRegexp(r.regexp.path) + } + + if r.regexp.host != nil { + c.regexp.host = copyRouteRegexp(r.regexp.host) + } + + c.regexp.queries = make([]*routeRegexp, 0, len(r.regexp.queries)) + for _, q := range r.regexp.queries { + c.regexp.queries = append(c.regexp.queries, copyRouteRegexp(q)) + } + + c.matchers = make([]matcher, 0, len(r.matchers)) + for _, m := range r.matchers { + c.matchers = append(c.matchers, m) + } + + return c +} + +func copyRouteRegexp(r *routeRegexp) *routeRegexp { + c := *r + return &c } // Match attempts to match the given request against the router's registered routes. @@ -164,13 +217,13 @@ func (r *Router) ServeHTTP(w http.ResponseWriter, req *http.Request) { // Get returns a route registered with the given name. func (r *Router) Get(name string) *Route { - return r.getNamedRoutes()[name] + return r.namedRoutes[name] } // GetRoute returns a route registered with the given name. This method // was renamed to Get() and remains here for backwards compatibility. func (r *Router) GetRoute(name string) *Route { - return r.getNamedRoutes()[name] + return r.namedRoutes[name] } // StrictSlash defines the trailing slash behavior for new routes. The initial @@ -221,51 +274,14 @@ func (r *Router) UseEncodedPath() *Router { return r } -// ---------------------------------------------------------------------------- -// parentRoute -// ---------------------------------------------------------------------------- - -func (r *Router) getBuildScheme() string { - if r.parent != nil { - return r.parent.getBuildScheme() - } - return "" -} - -// getNamedRoutes returns the map where named routes are registered. -func (r *Router) getNamedRoutes() map[string]*Route { - if r.namedRoutes == nil { - if r.parent != nil { - r.namedRoutes = r.parent.getNamedRoutes() - } else { - r.namedRoutes = make(map[string]*Route) - } - } - return r.namedRoutes -} - -// getRegexpGroup returns regexp definitions from the parent route, if any. -func (r *Router) getRegexpGroup() *routeRegexpGroup { - if r.parent != nil { - return r.parent.getRegexpGroup() - } - return nil -} - -func (r *Router) buildVars(m map[string]string) map[string]string { - if r.parent != nil { - m = r.parent.buildVars(m) - } - return m -} - // ---------------------------------------------------------------------------- // Route factories // ---------------------------------------------------------------------------- // NewRoute registers an empty route. func (r *Router) NewRoute() *Route { - route := &Route{parent: r, strictSlash: r.strictSlash, skipClean: r.skipClean, useEncodedPath: r.useEncodedPath} + // initialize a route with a copy of the parent router's configuration + route := &Route{routeConf: copyRouteConf(r.routeConf), namedRoutes: r.namedRoutes} r.routes = append(r.routes, route) return route } diff --git a/mux_test.go b/mux_test.go index 5d4027e4..519aa92c 100644 --- a/mux_test.go +++ b/mux_test.go @@ -48,15 +48,6 @@ type routeTest struct { } func TestHost(t *testing.T) { - // newRequestHost a new request with a method, url, and host header - newRequestHost := func(method, url, host string) *http.Request { - req, err := http.NewRequest(method, url, nil) - if err != nil { - panic(err) - } - req.Host = host - return req - } tests := []routeTest{ { @@ -1193,7 +1184,6 @@ func TestSubRouter(t *testing.T) { subrouter3 := new(Route).PathPrefix("/foo").Subrouter() subrouter4 := new(Route).PathPrefix("/foo/bar").Subrouter() subrouter5 := new(Route).PathPrefix("/{category}").Subrouter() - tests := []routeTest{ { route: subrouter1.Path("/{v2:[a-z]+}"), @@ -1288,6 +1278,106 @@ func TestSubRouter(t *testing.T) { pathTemplate: `/{category}`, shouldMatch: true, }, + { + title: "Mismatch method specified on parent route", + route: new(Route).Methods("POST").PathPrefix("/foo").Subrouter().Path("/"), + request: newRequest("GET", "http://localhost/foo/"), + vars: map[string]string{}, + host: "", + path: "/foo/", + pathTemplate: `/foo/`, + shouldMatch: false, + }, + { + title: "Match method specified on parent route", + route: new(Route).Methods("POST").PathPrefix("/foo").Subrouter().Path("/"), + request: newRequest("POST", "http://localhost/foo/"), + vars: map[string]string{}, + host: "", + path: "/foo/", + pathTemplate: `/foo/`, + shouldMatch: true, + }, + { + title: "Mismatch scheme specified on parent route", + route: new(Route).Schemes("https").Subrouter().PathPrefix("/"), + request: newRequest("GET", "http://localhost/"), + vars: map[string]string{}, + host: "", + path: "/", + pathTemplate: `/`, + shouldMatch: false, + }, + { + title: "Match scheme specified on parent route", + route: new(Route).Schemes("http").Subrouter().PathPrefix("/"), + request: newRequest("GET", "http://localhost/"), + vars: map[string]string{}, + host: "", + path: "/", + pathTemplate: `/`, + shouldMatch: true, + }, + { + title: "No match header specified on parent route", + route: new(Route).Headers("X-Forwarded-Proto", "https").Subrouter().PathPrefix("/"), + request: newRequest("GET", "http://localhost/"), + vars: map[string]string{}, + host: "", + path: "/", + pathTemplate: `/`, + shouldMatch: false, + }, + { + title: "Header mismatch value specified on parent route", + route: new(Route).Headers("X-Forwarded-Proto", "https").Subrouter().PathPrefix("/"), + request: newRequestWithHeaders("GET", "http://localhost/", "X-Forwarded-Proto", "http"), + vars: map[string]string{}, + host: "", + path: "/", + pathTemplate: `/`, + shouldMatch: false, + }, + { + title: "Header match value specified on parent route", + route: new(Route).Headers("X-Forwarded-Proto", "https").Subrouter().PathPrefix("/"), + request: newRequestWithHeaders("GET", "http://localhost/", "X-Forwarded-Proto", "https"), + vars: map[string]string{}, + host: "", + path: "/", + pathTemplate: `/`, + shouldMatch: true, + }, + { + title: "Query specified on parent route not present", + route: new(Route).Headers("key", "foobar").Subrouter().PathPrefix("/"), + request: newRequest("GET", "http://localhost/"), + vars: map[string]string{}, + host: "", + path: "/", + pathTemplate: `/`, + shouldMatch: false, + }, + { + title: "Query mismatch value specified on parent route", + route: new(Route).Queries("key", "foobar").Subrouter().PathPrefix("/"), + request: newRequest("GET", "http://localhost/?key=notfoobar"), + vars: map[string]string{}, + host: "", + path: "/", + pathTemplate: `/`, + shouldMatch: false, + }, + { + title: "Query match value specified on subroute", + route: new(Route).Queries("key", "foobar").Subrouter().PathPrefix("/"), + request: newRequest("GET", "http://localhost/?key=foobar"), + vars: map[string]string{}, + host: "", + path: "/", + pathTemplate: `/`, + shouldMatch: true, + }, { title: "Build with scheme on parent router", route: new(Route).Schemes("ftp").Host("google.com").Subrouter().Path("/"), @@ -1512,12 +1602,16 @@ func TestWalkSingleDepth(t *testing.T) { func TestWalkNested(t *testing.T) { router := NewRouter() - g := router.Path("/g").Subrouter() - o := g.PathPrefix("/o").Subrouter() - r := o.PathPrefix("/r").Subrouter() - i := r.PathPrefix("/i").Subrouter() - l1 := i.PathPrefix("/l").Subrouter() - l2 := l1.PathPrefix("/l").Subrouter() + routeSubrouter := func(r *Route) (*Route, *Router) { + return r, r.Subrouter() + } + + gRoute, g := routeSubrouter(router.Path("/g")) + oRoute, o := routeSubrouter(g.PathPrefix("/o")) + rRoute, r := routeSubrouter(o.PathPrefix("/r")) + iRoute, i := routeSubrouter(r.PathPrefix("/i")) + l1Route, l1 := routeSubrouter(i.PathPrefix("/l")) + l2Route, l2 := routeSubrouter(l1.PathPrefix("/l")) l2.Path("/a") testCases := []struct { @@ -1525,12 +1619,12 @@ func TestWalkNested(t *testing.T) { ancestors []*Route }{ {"/g", []*Route{}}, - {"/g/o", []*Route{g.parent.(*Route)}}, - {"/g/o/r", []*Route{g.parent.(*Route), o.parent.(*Route)}}, - {"/g/o/r/i", []*Route{g.parent.(*Route), o.parent.(*Route), r.parent.(*Route)}}, - {"/g/o/r/i/l", []*Route{g.parent.(*Route), o.parent.(*Route), r.parent.(*Route), i.parent.(*Route)}}, - {"/g/o/r/i/l/l", []*Route{g.parent.(*Route), o.parent.(*Route), r.parent.(*Route), i.parent.(*Route), l1.parent.(*Route)}}, - {"/g/o/r/i/l/l/a", []*Route{g.parent.(*Route), o.parent.(*Route), r.parent.(*Route), i.parent.(*Route), l1.parent.(*Route), l2.parent.(*Route)}}, + {"/g/o", []*Route{gRoute}}, + {"/g/o/r", []*Route{gRoute, oRoute}}, + {"/g/o/r/i", []*Route{gRoute, oRoute, rRoute}}, + {"/g/o/r/i/l", []*Route{gRoute, oRoute, rRoute, iRoute}}, + {"/g/o/r/i/l/l", []*Route{gRoute, oRoute, rRoute, iRoute, l1Route}}, + {"/g/o/r/i/l/l/a", []*Route{gRoute, oRoute, rRoute, iRoute, l1Route, l2Route}}, } idx := 0 @@ -1563,8 +1657,8 @@ func TestWalkSubrouters(t *testing.T) { o.Methods("GET") o.Methods("PUT") - // all 4 routes should be matched, but final 2 routes do not have path templates - paths := []string{"/g", "/g/o", "", ""} + // all 4 routes should be matched + paths := []string{"/g", "/g/o", "/g/o", "/g/o"} idx := 0 err := router.Walk(func(route *Route, router *Router, ancestors []*Route) error { path := paths[idx] @@ -1745,7 +1839,11 @@ func testRoute(t *testing.T, test routeTest) { } } if query != "" { - u, _ := route.URL(mapToPairs(match.Vars)...) + u, err := route.URL(mapToPairs(match.Vars)...) + if err != nil { + t.Errorf("(%v) erred while creating url: %v", test.title, err) + return + } if query != u.RawQuery { t.Errorf("(%v) URL query not equal: expected %v, got %v", test.title, query, u.RawQuery) return @@ -2332,6 +2430,273 @@ func testMethodsSubrouter(t *testing.T, test methodsSubrouterTest) { } } +func TestSubrouterMatching(t *testing.T) { + const ( + none, stdOnly, subOnly uint8 = 0, 1 << 0, 1 << 1 + both = subOnly | stdOnly + ) + + type request struct { + Name string + Request *http.Request + Flags uint8 + } + + cases := []struct { + Name string + Standard, Subrouter func(*Router) + Requests []request + }{ + { + "pathPrefix", + func(r *Router) { + r.PathPrefix("/before").PathPrefix("/after") + }, + func(r *Router) { + r.PathPrefix("/before").Subrouter().PathPrefix("/after") + }, + []request{ + {"no match final path prefix", newRequest("GET", "/after"), none}, + {"no match parent path prefix", newRequest("GET", "/before"), none}, + {"matches append", newRequest("GET", "/before/after"), both}, + {"matches as prefix", newRequest("GET", "/before/after/1234"), both}, + }, + }, + { + "path", + func(r *Router) { + r.Path("/before").Path("/after") + }, + func(r *Router) { + r.Path("/before").Subrouter().Path("/after") + }, + []request{ + {"no match subroute path", newRequest("GET", "/after"), none}, + {"no match parent path", newRequest("GET", "/before"), none}, + {"no match as prefix", newRequest("GET", "/before/after/1234"), none}, + {"no match append", newRequest("GET", "/before/after"), none}, + }, + }, + { + "host", + func(r *Router) { + r.Host("before.com").Host("after.com") + }, + func(r *Router) { + r.Host("before.com").Subrouter().Host("after.com") + }, + []request{ + {"no match before", newRequestHost("GET", "/", "before.com"), none}, + {"no match other", newRequestHost("GET", "/", "other.com"), none}, + {"matches after", newRequestHost("GET", "/", "after.com"), none}, + }, + }, + { + "queries variant keys", + func(r *Router) { + r.Queries("foo", "bar").Queries("cricket", "baseball") + }, + func(r *Router) { + r.Queries("foo", "bar").Subrouter().Queries("cricket", "baseball") + }, + []request{ + {"matches with all", newRequest("GET", "/?foo=bar&cricket=baseball"), both}, + {"matches with more", newRequest("GET", "/?foo=bar&cricket=baseball&something=else"), both}, + {"no match with none", newRequest("GET", "/"), none}, + {"no match with some", newRequest("GET", "/?cricket=baseball"), none}, + }, + }, + { + "queries overlapping keys", + func(r *Router) { + r.Queries("foo", "bar").Queries("foo", "baz") + }, + func(r *Router) { + r.Queries("foo", "bar").Subrouter().Queries("foo", "baz") + }, + []request{ + {"no match old value", newRequest("GET", "/?foo=bar"), none}, + {"no match diff value", newRequest("GET", "/?foo=bak"), none}, + {"no match with none", newRequest("GET", "/"), none}, + {"matches override", newRequest("GET", "/?foo=baz"), none}, + }, + }, + { + "header variant keys", + func(r *Router) { + r.Headers("foo", "bar").Headers("cricket", "baseball") + }, + func(r *Router) { + r.Headers("foo", "bar").Subrouter().Headers("cricket", "baseball") + }, + []request{ + { + "matches with all", + newRequestWithHeaders("GET", "/", "foo", "bar", "cricket", "baseball"), + both, + }, + { + "matches with more", + newRequestWithHeaders("GET", "/", "foo", "bar", "cricket", "baseball", "something", "else"), + both, + }, + {"no match with none", newRequest("GET", "/"), none}, + {"no match with some", newRequestWithHeaders("GET", "/", "cricket", "baseball"), none}, + }, + }, + { + "header overlapping keys", + func(r *Router) { + r.Headers("foo", "bar").Headers("foo", "baz") + }, + func(r *Router) { + r.Headers("foo", "bar").Subrouter().Headers("foo", "baz") + }, + []request{ + {"no match old value", newRequestWithHeaders("GET", "/", "foo", "bar"), none}, + {"no match diff value", newRequestWithHeaders("GET", "/", "foo", "bak"), none}, + {"no match with none", newRequest("GET", "/"), none}, + {"matches override", newRequestWithHeaders("GET", "/", "foo", "baz"), none}, + }, + }, + { + "method", + func(r *Router) { + r.Methods("POST").Methods("GET") + }, + func(r *Router) { + r.Methods("POST").Subrouter().Methods("GET") + }, + []request{ + {"matches before", newRequest("POST", "/"), none}, + {"no match other", newRequest("HEAD", "/"), none}, + {"matches override", newRequest("GET", "/"), none}, + }, + }, + { + "schemes", + func(r *Router) { + r.Schemes("http").Schemes("https") + }, + func(r *Router) { + r.Schemes("http").Subrouter().Schemes("https") + }, + []request{ + {"matches overrides", newRequest("GET", "https://www.example.com/"), none}, + {"matches original", newRequest("GET", "http://www.example.com/"), none}, + {"no match other", newRequest("GET", "ftp://www.example.com/"), none}, + }, + }, + } + + // case -> request -> router + for _, c := range cases { + t.Run(c.Name, func(t *testing.T) { + for _, req := range c.Requests { + t.Run(req.Name, func(t *testing.T) { + for _, v := range []struct { + Name string + Config func(*Router) + Expected bool + }{ + {"subrouter", c.Subrouter, (req.Flags & subOnly) != 0}, + {"standard", c.Standard, (req.Flags & stdOnly) != 0}, + } { + r := NewRouter() + v.Config(r) + if r.Match(req.Request, &RouteMatch{}) != v.Expected { + if v.Expected { + t.Errorf("expected %v match", v.Name) + } else { + t.Errorf("expected %v no match", v.Name) + } + } + } + }) + } + }) + } +} + +// verify that copyRouteConf copies fields as expected. +func Test_copyRouteConf(t *testing.T) { + var ( + m MatcherFunc = func(*http.Request, *RouteMatch) bool { + return true + } + b BuildVarsFunc = func(i map[string]string) map[string]string { + return i + } + r, _ = newRouteRegexp("hi", regexpTypeHost, routeRegexpOptions{}) + ) + + tests := []struct { + name string + args routeConf + want routeConf + }{ + { + "empty", + routeConf{}, + routeConf{}, + }, + { + "full", + routeConf{ + useEncodedPath: true, + strictSlash: true, + skipClean: true, + regexp: routeRegexpGroup{host: r, path: r, queries: []*routeRegexp{r}}, + matchers: []matcher{m}, + buildScheme: "https", + buildVarsFunc: b, + }, + routeConf{ + useEncodedPath: true, + strictSlash: true, + skipClean: true, + regexp: routeRegexpGroup{host: r, path: r, queries: []*routeRegexp{r}}, + matchers: []matcher{m}, + buildScheme: "https", + buildVarsFunc: b, + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // special case some incomparable fields of routeConf before delegating to reflect.DeepEqual + got := copyRouteConf(tt.args) + + // funcs not comparable, just compare length of slices + if len(got.matchers) != len(tt.want.matchers) { + t.Errorf("matchers different lengths: %v %v", len(got.matchers), len(tt.want.matchers)) + } + got.matchers, tt.want.matchers = nil, nil + + // deep equal treats nil slice differently to empty slice so check for zero len first + { + bothZero := len(got.regexp.queries) == 0 && len(tt.want.regexp.queries) == 0 + if !bothZero && !reflect.DeepEqual(got.regexp.queries, tt.want.regexp.queries) { + t.Errorf("queries unequal: %v %v", got.regexp.queries, tt.want.regexp.queries) + } + got.regexp.queries, tt.want.regexp.queries = nil, nil + } + + // funcs not comparable, just compare nullity + if (got.buildVarsFunc == nil) != (tt.want.buildVarsFunc == nil) { + t.Errorf("build vars funcs unequal: %v %v", got.buildVarsFunc == nil, tt.want.buildVarsFunc == nil) + } + got.buildVarsFunc, tt.want.buildVarsFunc = nil, nil + + // finish the deal + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("route confs unequal: %v %v", got, tt.want) + } + }) + } +} + // mapToPairs converts a string map to a slice of string pairs func mapToPairs(m map[string]string) []string { var i int @@ -2406,3 +2771,28 @@ func newRequest(method, url string) *http.Request { } return req } + +// create a new request with the provided headers +func newRequestWithHeaders(method, url string, headers ...string) *http.Request { + req := newRequest(method, url) + + if len(headers)%2 != 0 { + panic(fmt.Sprintf("Expected headers length divisible by 2 but got %v", len(headers))) + } + + for i := 0; i < len(headers); i += 2 { + req.Header.Set(headers[i], headers[i+1]) + } + + return req +} + +// newRequestHost a new request with a method, url, and host header +func newRequestHost(method, url, host string) *http.Request { + req, err := http.NewRequest(method, url, nil) + if err != nil { + panic(err) + } + req.Host = host + return req +} diff --git a/regexp.go b/regexp.go index b92d59f2..7c7405d1 100644 --- a/regexp.go +++ b/regexp.go @@ -267,7 +267,7 @@ type routeRegexpGroup struct { } // setMatch extracts the variables from the URL once a route matches. -func (v *routeRegexpGroup) setMatch(req *http.Request, m *RouteMatch, r *Route) { +func (v routeRegexpGroup) setMatch(req *http.Request, m *RouteMatch, r *Route) { // Store host variables. if v.host != nil { host := getHost(req) diff --git a/route.go b/route.go index c8bb5c7e..acef9195 100644 --- a/route.go +++ b/route.go @@ -15,24 +15,8 @@ import ( // Route stores information to match a request and build URLs. type Route struct { - // Parent where the route was registered (a Router). - parent parentRoute // Request handler for the route. handler http.Handler - // List of matchers. - matchers []matcher - // Manager for the variables from host and path. - regexp *routeRegexpGroup - // If true, when the path pattern is "/path/", accessing "/path" will - // redirect to the former and vice versa. - strictSlash bool - // If true, when the path pattern is "/path//to", accessing "/path//to" - // will not redirect - skipClean bool - // If true, "/path/foo%2Fbar/to" will match the path "/path/{var}/to" - useEncodedPath bool - // The scheme used when building URLs. - buildScheme string // If true, this route never matches: it is only used to build URLs. buildOnly bool // The name used to build URLs. @@ -40,7 +24,11 @@ type Route struct { // Error resulted from building a route. err error - buildVarsFunc BuildVarsFunc + // "global" reference to all named routes + namedRoutes map[string]*Route + + // config possibly passed in from `Router` + routeConf } // SkipClean reports whether path cleaning is enabled for this route via @@ -93,9 +81,7 @@ func (r *Route) Match(req *http.Request, match *RouteMatch) bool { } // Set variables. - if r.regexp != nil { - r.regexp.setMatch(req, match, r) - } + r.regexp.setMatch(req, match, r) return true } @@ -145,7 +131,7 @@ func (r *Route) Name(name string) *Route { } if r.err == nil { r.name = name - r.getNamedRoutes()[name] = r + r.namedRoutes[name] = r } return r } @@ -177,7 +163,6 @@ func (r *Route) addRegexpMatcher(tpl string, typ regexpType) error { if r.err != nil { return r.err } - r.regexp = r.getRegexpGroup() if typ == regexpTypePath || typ == regexpTypePrefix { if len(tpl) > 0 && tpl[0] != '/' { return fmt.Errorf("mux: path must start with a slash, got %q", tpl) @@ -424,7 +409,7 @@ func (r *Route) Schemes(schemes ...string) *Route { for k, v := range schemes { schemes[k] = strings.ToLower(v) } - if r.buildScheme == "" && len(schemes) > 0 { + if len(schemes) > 0 { r.buildScheme = schemes[0] } return r.addMatcher(schemeMatcher(schemes)) @@ -439,7 +424,15 @@ type BuildVarsFunc func(map[string]string) map[string]string // BuildVarsFunc adds a custom function to be used to modify build variables // before a route's URL is built. func (r *Route) BuildVarsFunc(f BuildVarsFunc) *Route { - r.buildVarsFunc = f + if r.buildVarsFunc != nil { + // compose the old and new functions + old := r.buildVarsFunc + r.buildVarsFunc = func(m map[string]string) map[string]string { + return f(old(m)) + } + } else { + r.buildVarsFunc = f + } return r } @@ -458,7 +451,8 @@ func (r *Route) BuildVarsFunc(f BuildVarsFunc) *Route { // Here, the routes registered in the subrouter won't be tested if the host // doesn't match. func (r *Route) Subrouter() *Router { - router := &Router{parent: r, strictSlash: r.strictSlash} + // initialize a subrouter with a copy of the parent route's configuration + router := &Router{routeConf: copyRouteConf(r.routeConf), namedRoutes: r.namedRoutes} r.addMatcher(router) return router } @@ -502,9 +496,6 @@ func (r *Route) URL(pairs ...string) (*url.URL, error) { if r.err != nil { return nil, r.err } - if r.regexp == nil { - return nil, errors.New("mux: route doesn't have a host or path") - } values, err := r.prepareVars(pairs...) if err != nil { return nil, err @@ -516,8 +507,8 @@ func (r *Route) URL(pairs ...string) (*url.URL, error) { return nil, err } scheme = "http" - if s := r.getBuildScheme(); s != "" { - scheme = s + if r.buildScheme != "" { + scheme = r.buildScheme } } if r.regexp.path != nil { @@ -547,7 +538,7 @@ func (r *Route) URLHost(pairs ...string) (*url.URL, error) { if r.err != nil { return nil, r.err } - if r.regexp == nil || r.regexp.host == nil { + if r.regexp.host == nil { return nil, errors.New("mux: route doesn't have a host") } values, err := r.prepareVars(pairs...) @@ -562,8 +553,8 @@ func (r *Route) URLHost(pairs ...string) (*url.URL, error) { Scheme: "http", Host: host, } - if s := r.getBuildScheme(); s != "" { - u.Scheme = s + if r.buildScheme != "" { + u.Scheme = r.buildScheme } return u, nil } @@ -575,7 +566,7 @@ func (r *Route) URLPath(pairs ...string) (*url.URL, error) { if r.err != nil { return nil, r.err } - if r.regexp == nil || r.regexp.path == nil { + if r.regexp.path == nil { return nil, errors.New("mux: route doesn't have a path") } values, err := r.prepareVars(pairs...) @@ -600,7 +591,7 @@ func (r *Route) GetPathTemplate() (string, error) { if r.err != nil { return "", r.err } - if r.regexp == nil || r.regexp.path == nil { + if r.regexp.path == nil { return "", errors.New("mux: route doesn't have a path") } return r.regexp.path.template, nil @@ -614,7 +605,7 @@ func (r *Route) GetPathRegexp() (string, error) { if r.err != nil { return "", r.err } - if r.regexp == nil || r.regexp.path == nil { + if r.regexp.path == nil { return "", errors.New("mux: route does not have a path") } return r.regexp.path.regexp.String(), nil @@ -629,7 +620,7 @@ func (r *Route) GetQueriesRegexp() ([]string, error) { if r.err != nil { return nil, r.err } - if r.regexp == nil || r.regexp.queries == nil { + if r.regexp.queries == nil { return nil, errors.New("mux: route doesn't have queries") } var queries []string @@ -648,7 +639,7 @@ func (r *Route) GetQueriesTemplates() ([]string, error) { if r.err != nil { return nil, r.err } - if r.regexp == nil || r.regexp.queries == nil { + if r.regexp.queries == nil { return nil, errors.New("mux: route doesn't have queries") } var queries []string @@ -683,7 +674,7 @@ func (r *Route) GetHostTemplate() (string, error) { if r.err != nil { return "", r.err } - if r.regexp == nil || r.regexp.host == nil { + if r.regexp.host == nil { return "", errors.New("mux: route doesn't have a host") } return r.regexp.host.template, nil @@ -700,64 +691,8 @@ func (r *Route) prepareVars(pairs ...string) (map[string]string, error) { } func (r *Route) buildVars(m map[string]string) map[string]string { - if r.parent != nil { - m = r.parent.buildVars(m) - } if r.buildVarsFunc != nil { m = r.buildVarsFunc(m) } return m } - -// ---------------------------------------------------------------------------- -// parentRoute -// ---------------------------------------------------------------------------- - -// parentRoute allows routes to know about parent host and path definitions. -type parentRoute interface { - getBuildScheme() string - getNamedRoutes() map[string]*Route - getRegexpGroup() *routeRegexpGroup - buildVars(map[string]string) map[string]string -} - -func (r *Route) getBuildScheme() string { - if r.buildScheme != "" { - return r.buildScheme - } - if r.parent != nil { - return r.parent.getBuildScheme() - } - return "" -} - -// getNamedRoutes returns the map where named routes are registered. -func (r *Route) getNamedRoutes() map[string]*Route { - if r.parent == nil { - // During tests router is not always set. - r.parent = NewRouter() - } - return r.parent.getNamedRoutes() -} - -// getRegexpGroup returns regexp definitions from this route. -func (r *Route) getRegexpGroup() *routeRegexpGroup { - if r.regexp == nil { - if r.parent == nil { - // During tests router is not always set. - r.parent = NewRouter() - } - regexp := r.parent.getRegexpGroup() - if regexp == nil { - r.regexp = new(routeRegexpGroup) - } else { - // Copy. - r.regexp = &routeRegexpGroup{ - host: regexp.host, - path: regexp.path, - queries: regexp.queries, - } - } - } - return r.regexp -}