Skip to content

Commit 5a61040

Browse files
sonu27fenollp
andauthored
Match on overridden servers at the path level, fixes #564 (#565)
Co-authored-by: Pierre Fenoll <pierrefenoll@gmail.com>
1 parent 14af893 commit 5a61040

File tree

2 files changed

+125
-66
lines changed

2 files changed

+125
-66
lines changed

routers/gorillamux/router.go

Lines changed: 84 additions & 66 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,13 @@ type routeMux struct {
3434
varsUpdater varsf
3535
}
3636

37+
type srv struct {
38+
schemes []string
39+
host, base string
40+
server *openapi3.Server
41+
varsUpdater varsf
42+
}
43+
3744
var singleVariableMatcher = regexp.MustCompile(`^\{([^{}]+)\}$`)
3845

3946
// TODO: Handle/HandlerFunc + ServeHTTP (When there is a match, the route variables can be retrieved calling mux.Vars(request))
@@ -42,78 +49,22 @@ var singleVariableMatcher = regexp.MustCompile(`^\{([^{}]+)\}$`)
4249
// Assumes spec is .Validate()d
4350
// Note that a variable for the port number MUST have a default value and only this value will match as the port (see issue #367).
4451
func NewRouter(doc *openapi3.T) (routers.Router, error) {
45-
type srv struct {
46-
schemes []string
47-
host, base string
48-
server *openapi3.Server
49-
varsUpdater varsf
52+
servers, err := makeServers(doc.Servers)
53+
if err != nil {
54+
return nil, err
5055
}
51-
servers := make([]srv, 0, len(doc.Servers))
52-
for _, server := range doc.Servers {
53-
serverURL := server.URL
54-
if submatch := singleVariableMatcher.FindStringSubmatch(serverURL); submatch != nil {
55-
sVar := submatch[1]
56-
sVal := server.Variables[sVar].Default
57-
serverURL = strings.ReplaceAll(serverURL, "{"+sVar+"}", sVal)
58-
var varsUpdater varsf
59-
if lhs := strings.TrimSuffix(serverURL, server.Variables[sVar].Default); lhs != "" {
60-
varsUpdater = func(vars map[string]string) { vars[sVar] = lhs }
61-
}
62-
servers = append(servers, srv{
63-
base: server.Variables[sVar].Default,
64-
server: server,
65-
varsUpdater: varsUpdater,
66-
})
67-
continue
68-
}
69-
70-
var schemes []string
71-
if strings.Contains(serverURL, "://") {
72-
scheme0 := strings.Split(serverURL, "://")[0]
73-
schemes = permutePart(scheme0, server)
74-
serverURL = strings.Replace(serverURL, scheme0+"://", schemes[0]+"://", 1)
75-
}
7656

77-
// If a variable represents the port "http://domain.tld:{port}/bla"
78-
// then url.Parse() cannot parse "http://domain.tld:`bEncode({port})`/bla"
79-
// and mux is not able to set the {port} variable
80-
// So we just use the default value for this variable.
81-
// See https://github.com/getkin/kin-openapi/issues/367
82-
var varsUpdater varsf
83-
if lhs := strings.Index(serverURL, ":{"); lhs > 0 {
84-
rest := serverURL[lhs+len(":{"):]
85-
rhs := strings.Index(rest, "}")
86-
portVariable := rest[:rhs]
87-
portValue := server.Variables[portVariable].Default
88-
serverURL = strings.ReplaceAll(serverURL, "{"+portVariable+"}", portValue)
89-
varsUpdater = func(vars map[string]string) {
90-
vars[portVariable] = portValue
91-
}
92-
}
93-
94-
u, err := url.Parse(bEncode(serverURL))
95-
if err != nil {
96-
return nil, err
97-
}
98-
path := bDecode(u.EscapedPath())
99-
if len(path) > 0 && path[len(path)-1] == '/' {
100-
path = path[:len(path)-1]
101-
}
102-
servers = append(servers, srv{
103-
host: bDecode(u.Host), //u.Hostname()?
104-
base: path,
105-
schemes: schemes, // scheme: []string{scheme0}, TODO: https://github.com/gorilla/mux/issues/624
106-
server: server,
107-
varsUpdater: varsUpdater,
108-
})
109-
}
110-
if len(servers) == 0 {
111-
servers = append(servers, srv{})
112-
}
11357
muxRouter := mux.NewRouter().UseEncodedPath()
11458
r := &Router{}
11559
for _, path := range orderedPaths(doc.Paths) {
60+
servers := servers
61+
11662
pathItem := doc.Paths[path]
63+
if len(pathItem.Servers) > 0 {
64+
if servers, err = makeServers(pathItem.Servers); err != nil {
65+
return nil, err
66+
}
67+
}
11768

11869
operations := pathItem.Operations()
11970
methods := make([]string, 0, len(operations))
@@ -177,6 +128,73 @@ func (r *Router) FindRoute(req *http.Request) (*routers.Route, map[string]string
177128
return nil, nil, routers.ErrPathNotFound
178129
}
179130

131+
func makeServers(in openapi3.Servers) ([]srv, error) {
132+
servers := make([]srv, 0, len(in))
133+
for _, server := range in {
134+
serverURL := server.URL
135+
if submatch := singleVariableMatcher.FindStringSubmatch(serverURL); submatch != nil {
136+
sVar := submatch[1]
137+
sVal := server.Variables[sVar].Default
138+
serverURL = strings.ReplaceAll(serverURL, "{"+sVar+"}", sVal)
139+
var varsUpdater varsf
140+
if lhs := strings.TrimSuffix(serverURL, server.Variables[sVar].Default); lhs != "" {
141+
varsUpdater = func(vars map[string]string) { vars[sVar] = lhs }
142+
}
143+
servers = append(servers, srv{
144+
base: server.Variables[sVar].Default,
145+
server: server,
146+
varsUpdater: varsUpdater,
147+
})
148+
continue
149+
}
150+
151+
var schemes []string
152+
if strings.Contains(serverURL, "://") {
153+
scheme0 := strings.Split(serverURL, "://")[0]
154+
schemes = permutePart(scheme0, server)
155+
serverURL = strings.Replace(serverURL, scheme0+"://", schemes[0]+"://", 1)
156+
}
157+
158+
// If a variable represents the port "http://domain.tld:{port}/bla"
159+
// then url.Parse() cannot parse "http://domain.tld:`bEncode({port})`/bla"
160+
// and mux is not able to set the {port} variable
161+
// So we just use the default value for this variable.
162+
// See https://github.com/getkin/kin-openapi/issues/367
163+
var varsUpdater varsf
164+
if lhs := strings.Index(serverURL, ":{"); lhs > 0 {
165+
rest := serverURL[lhs+len(":{"):]
166+
rhs := strings.Index(rest, "}")
167+
portVariable := rest[:rhs]
168+
portValue := server.Variables[portVariable].Default
169+
serverURL = strings.ReplaceAll(serverURL, "{"+portVariable+"}", portValue)
170+
varsUpdater = func(vars map[string]string) {
171+
vars[portVariable] = portValue
172+
}
173+
}
174+
175+
u, err := url.Parse(bEncode(serverURL))
176+
if err != nil {
177+
return nil, err
178+
}
179+
path := bDecode(u.EscapedPath())
180+
if len(path) > 0 && path[len(path)-1] == '/' {
181+
path = path[:len(path)-1]
182+
}
183+
servers = append(servers, srv{
184+
host: bDecode(u.Host), //u.Hostname()?
185+
base: path,
186+
schemes: schemes, // scheme: []string{scheme0}, TODO: https://github.com/gorilla/mux/issues/624
187+
server: server,
188+
varsUpdater: varsUpdater,
189+
})
190+
}
191+
if len(servers) == 0 {
192+
servers = append(servers, srv{})
193+
}
194+
195+
return servers, nil
196+
}
197+
180198
func orderedPaths(paths map[string]*openapi3.PathItem) []string {
181199
// https://github.com/OAI/OpenAPI-Specification/blob/master/versions/3.0.3.md#pathsObject
182200
// When matching URLs, concrete (non-templated) paths would be matched

routers/gorillamux/router_test.go

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -254,6 +254,47 @@ func TestServerPath(t *testing.T) {
254254
require.NoError(t, err)
255255
}
256256

257+
func TestServerOverrideAtPathLevel(t *testing.T) {
258+
helloGET := &openapi3.Operation{Responses: openapi3.NewResponses()}
259+
doc := &openapi3.T{
260+
OpenAPI: "3.0.0",
261+
Info: &openapi3.Info{
262+
Title: "rel",
263+
Version: "1",
264+
},
265+
Servers: openapi3.Servers{
266+
&openapi3.Server{
267+
URL: "https://example.com",
268+
},
269+
},
270+
Paths: openapi3.Paths{
271+
"/hello": &openapi3.PathItem{
272+
Servers: openapi3.Servers{
273+
&openapi3.Server{
274+
URL: "https://another.com",
275+
},
276+
},
277+
Get: helloGET,
278+
},
279+
},
280+
}
281+
err := doc.Validate(context.Background())
282+
require.NoError(t, err)
283+
router, err := NewRouter(doc)
284+
require.NoError(t, err)
285+
286+
req, err := http.NewRequest(http.MethodGet, "https://another.com/hello", nil)
287+
require.NoError(t, err)
288+
route, _, err := router.FindRoute(req)
289+
require.Equal(t, "/hello", route.Path)
290+
291+
req, err = http.NewRequest(http.MethodGet, "https://example.com/hello", nil)
292+
require.NoError(t, err)
293+
route, _, err = router.FindRoute(req)
294+
require.Nil(t, route)
295+
require.Error(t, err)
296+
}
297+
257298
func TestRelativeURL(t *testing.T) {
258299
helloGET := &openapi3.Operation{Responses: openapi3.NewResponses()}
259300
doc := &openapi3.T{

0 commit comments

Comments
 (0)