Skip to content

Commit 5bc3612

Browse files
committed
Add openapi3 validator middleware
This change introduces middleware which wraps an `http.Handler` with OpenAPI 3 request and response validation.
1 parent 4ce78d8 commit 5bc3612

File tree

2 files changed

+777
-0
lines changed

2 files changed

+777
-0
lines changed

openapi3filter/middleware.go

Lines changed: 273 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,273 @@
1+
package openapi3filter
2+
3+
import (
4+
"bytes"
5+
"io"
6+
"io/ioutil"
7+
"log"
8+
"net/http"
9+
10+
"github.com/getkin/kin-openapi/routers"
11+
)
12+
13+
// Validator provides HTTP request and response validation middleware.
14+
type Validator struct {
15+
router routers.Router
16+
errFunc ErrFunc
17+
logFunc LogFunc
18+
strict bool
19+
}
20+
21+
// ErrFunc handles errors that may occur during validation.
22+
type ErrFunc func(w http.ResponseWriter, status int, code ErrCode, err error)
23+
24+
// LogFunc handles log messages that may occur during validation.
25+
type LogFunc func(message string, err error)
26+
27+
// ErrCode is used for classification of different types of errors that may
28+
// occur during validation. These may be used to write an appropriate response
29+
// in ErrFunc.
30+
type ErrCode int
31+
32+
const (
33+
// ErrCodeOK indicates no error. It is also the default value.
34+
ErrCodeOK = 0
35+
// ErrCodeCannotFindRoute happens when the validator fails to resolve the
36+
// request to a defined OpenAPI route.
37+
ErrCodeCannotFindRoute = iota
38+
// ErrCodeRequestInvalid happens when the inbound request does not conform
39+
// to the OpenAPI 3 specification.
40+
ErrCodeRequestInvalid = iota
41+
// ErrCodeResponseInvalid happens when the wrapped handler response does
42+
// not conform to the OpenAPI 3 specification.
43+
ErrCodeResponseInvalid = iota
44+
)
45+
46+
func (e ErrCode) responseText() string {
47+
switch e {
48+
case ErrCodeOK:
49+
return "OK"
50+
case ErrCodeCannotFindRoute:
51+
return "not found"
52+
case ErrCodeRequestInvalid:
53+
return "bad request"
54+
default:
55+
return "server error"
56+
}
57+
}
58+
59+
// NewValidator returns a new response validation middlware, using the given
60+
// routes from an OpenAPI 3 specification.
61+
func NewValidator(router routers.Router, options ...ValidatorOption) *Validator {
62+
v := &Validator{
63+
router: router,
64+
errFunc: func(w http.ResponseWriter, status int, code ErrCode, _ error) {
65+
http.Error(w, code.responseText(), status)
66+
},
67+
logFunc: func(message string, err error) {
68+
log.Printf("%s: %v", message, err)
69+
},
70+
}
71+
for i := range options {
72+
options[i](v)
73+
}
74+
return v
75+
}
76+
77+
// ValidatorOption defines an option that may be specified when creating a
78+
// Validator.
79+
type ValidatorOption func(*Validator)
80+
81+
// OnErr provides a callback that handles writing an HTTP response on a
82+
// validation error. This allows customization of error responses without
83+
// prescribing a particular form. This callback is only called on response
84+
// validator errors in Strict mode.
85+
func OnErr(f ErrFunc) ValidatorOption {
86+
return func(v *Validator) {
87+
v.errFunc = f
88+
}
89+
}
90+
91+
// OnLog provides a callback that handles logging in the Validator. This allows
92+
// the validator to integrate with a services' existing logging system without
93+
// prescribing a particular one.
94+
func OnLog(f LogFunc) ValidatorOption {
95+
return func(v *Validator) {
96+
v.logFunc = f
97+
}
98+
}
99+
100+
// Strict, if set, causes an internal server error to be sent if the wrapped
101+
// handler response fails response validation. If not set, the response is sent
102+
// and the error is only logged.
103+
func Strict(strict bool) ValidatorOption {
104+
return func(v *Validator) {
105+
v.strict = strict
106+
}
107+
}
108+
109+
// Middleware returns an http.Handler which wraps the given handler with
110+
// request and response validation.
111+
func (v *Validator) Middleware(h http.Handler) http.Handler {
112+
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
113+
route, pathParams, err := v.router.FindRoute(r)
114+
if err != nil {
115+
v.logFunc("validation error: failed to find route for "+r.URL.String(), err)
116+
v.errFunc(w, http.StatusNotFound, ErrCodeCannotFindRoute, err)
117+
return
118+
}
119+
requestValidationInput := &RequestValidationInput{
120+
Request: r,
121+
PathParams: pathParams,
122+
Route: route,
123+
}
124+
if err = ValidateRequest(r.Context(), requestValidationInput); err != nil {
125+
v.logFunc("invalid request", err)
126+
v.errFunc(w, http.StatusBadRequest, ErrCodeRequestInvalid, err)
127+
return
128+
}
129+
130+
var wr responseWrapper
131+
if v.strict {
132+
wr = &strictResponseWrapper{w: w}
133+
} else {
134+
wr = newWarnResponseWrapper(w)
135+
}
136+
137+
h.ServeHTTP(wr, r)
138+
139+
if err = ValidateResponse(r.Context(), &ResponseValidationInput{
140+
RequestValidationInput: requestValidationInput,
141+
Status: wr.statusCode(),
142+
Header: wr.Header(),
143+
Body: ioutil.NopCloser(bytes.NewBuffer(wr.bodyContents())),
144+
}); err != nil {
145+
v.logFunc("invalid response", err)
146+
if v.strict {
147+
v.errFunc(w, http.StatusInternalServerError, ErrCodeResponseInvalid, err)
148+
}
149+
return
150+
}
151+
152+
if err = wr.flushBodyContents(); err != nil {
153+
v.logFunc("failed to write response", err)
154+
}
155+
})
156+
}
157+
158+
type responseWrapper interface {
159+
http.ResponseWriter
160+
161+
// flushBodyContents writes the buffered response to the client, if it has
162+
// not yet been written.
163+
flushBodyContents() error
164+
165+
// statusCode returns the response status code, 0 if not set yet.
166+
statusCode() int
167+
168+
// bodyContents returns the buffered
169+
bodyContents() []byte
170+
}
171+
172+
type warnResponseWrapper struct {
173+
w http.ResponseWriter
174+
headerWritten bool
175+
status int
176+
body bytes.Buffer
177+
tee io.Writer
178+
}
179+
180+
func newWarnResponseWrapper(w http.ResponseWriter) *warnResponseWrapper {
181+
wr := &warnResponseWrapper{
182+
w: w,
183+
}
184+
wr.tee = io.MultiWriter(w, &wr.body)
185+
return wr
186+
}
187+
188+
// Write implements http.ResponseWriter.
189+
func (wr *warnResponseWrapper) Write(b []byte) (int, error) {
190+
if !wr.headerWritten {
191+
wr.WriteHeader(http.StatusOK)
192+
}
193+
return wr.tee.Write(b)
194+
}
195+
196+
// WriteHeader implements http.ResponseWriter.
197+
func (wr *warnResponseWrapper) WriteHeader(status int) {
198+
if !wr.headerWritten {
199+
// If the header hasn't been written, record the status for response
200+
// validation.
201+
wr.status = status
202+
wr.headerWritten = true
203+
}
204+
wr.w.WriteHeader(wr.status)
205+
}
206+
207+
// Header implements http.ResponseWriter.
208+
func (wr *warnResponseWrapper) Header() http.Header {
209+
return wr.w.Header()
210+
}
211+
212+
// Flush implements the optional http.Flusher interface.
213+
func (wr *warnResponseWrapper) Flush() {
214+
// If the wrapped http.ResponseWriter implements optional http.Flusher,
215+
// pass through.
216+
if fl, ok := wr.w.(http.Flusher); ok {
217+
fl.Flush()
218+
}
219+
}
220+
221+
func (wr *warnResponseWrapper) flushBodyContents() error {
222+
return nil
223+
}
224+
225+
func (wr *warnResponseWrapper) statusCode() int {
226+
return wr.status
227+
}
228+
229+
func (wr *warnResponseWrapper) bodyContents() []byte {
230+
return wr.body.Bytes()
231+
}
232+
233+
type strictResponseWrapper struct {
234+
w http.ResponseWriter
235+
headerWritten bool
236+
status int
237+
body bytes.Buffer
238+
}
239+
240+
// Write implements http.ResponseWriter.
241+
func (wr *strictResponseWrapper) Write(b []byte) (int, error) {
242+
if !wr.headerWritten {
243+
wr.WriteHeader(http.StatusOK)
244+
}
245+
return wr.body.Write(b)
246+
}
247+
248+
// WriteHeader implements http.ResponseWriter.
249+
func (wr *strictResponseWrapper) WriteHeader(status int) {
250+
if !wr.headerWritten {
251+
wr.status = status
252+
wr.headerWritten = true
253+
}
254+
}
255+
256+
// Header implements http.ResponseWriter.
257+
func (wr *strictResponseWrapper) Header() http.Header {
258+
return wr.w.Header()
259+
}
260+
261+
func (wr *strictResponseWrapper) flushBodyContents() error {
262+
wr.w.WriteHeader(wr.status)
263+
_, err := wr.w.Write(wr.body.Bytes())
264+
return err
265+
}
266+
267+
func (wr *strictResponseWrapper) statusCode() int {
268+
return wr.status
269+
}
270+
271+
func (wr *strictResponseWrapper) bodyContents() []byte {
272+
return wr.body.Bytes()
273+
}

0 commit comments

Comments
 (0)