Skip to content

Commit 3a6247b

Browse files
committed
Add openapi3 validator middleware
This change introduces middleware which wraps an `http.Handler` with OpenAPI 3 request and response validation.
1 parent ceb64e7 commit 3a6247b

File tree

2 files changed

+656
-0
lines changed

2 files changed

+656
-0
lines changed

openapi3filter/middleware.go

Lines changed: 272 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,272 @@
1+
package openapi3filter
2+
3+
import (
4+
"bytes"
5+
"io"
6+
"log"
7+
"net/http"
8+
9+
"github.com/getkin/kin-openapi/routers"
10+
)
11+
12+
// Validator provides HTTP request and response validation middleware.
13+
type Validator struct {
14+
router routers.Router
15+
errFunc ErrFunc
16+
logFunc LogFunc
17+
strict bool
18+
}
19+
20+
// ErrFunc handles errors that may occur during validation.
21+
type ErrFunc func(w http.ResponseWriter, status int, code ErrCode, err error)
22+
23+
// LogFunc handles log messages that may occur during validation.
24+
type LogFunc func(message string, err error)
25+
26+
// ErrCode is used for classification of different types of errors that may
27+
// occur during validation. These may be used to write an appropriate response
28+
// in ErrFunc.
29+
type ErrCode int
30+
31+
const (
32+
// ErrCodeOK indicates no error. It is also the default value.
33+
ErrCodeOK = 0
34+
// ErrCodeCannotFindRoute happens when the validator fails to resolve the
35+
// request to a defined OpenAPI route.
36+
ErrCodeCannotFindRoute = iota
37+
// ErrCodeRequestInvalid happens when the inbound request does not conform
38+
// to the OpenAPI 3 specification.
39+
ErrCodeRequestInvalid = iota
40+
// ErrCodeResponseInvalid happens when the wrapped handler response does
41+
// not conform to the OpenAPI 3 specification.
42+
ErrCodeResponseInvalid = iota
43+
)
44+
45+
func (e ErrCode) responseText() string {
46+
switch e {
47+
case ErrCodeOK:
48+
return "OK"
49+
case ErrCodeCannotFindRoute:
50+
return "not found"
51+
case ErrCodeRequestInvalid:
52+
return "bad request"
53+
default:
54+
return "server error"
55+
}
56+
}
57+
58+
// NewValidator returns a new response validation middlware, using the given
59+
// routes from an OpenAPI 3 specification.
60+
func NewValidator(router routers.Router, options ...ValidatorOption) *Validator {
61+
v := &Validator{
62+
router: router,
63+
errFunc: func(w http.ResponseWriter, status int, code ErrCode, _ error) {
64+
http.Error(w, code.responseText(), status)
65+
},
66+
logFunc: func(message string, err error) {
67+
log.Printf("%s: %v", message, err)
68+
},
69+
}
70+
for i := range options {
71+
options[i](v)
72+
}
73+
return v
74+
}
75+
76+
// ValidatorOption defines an option that may be specified when creating a
77+
// Validator.
78+
type ValidatorOption func(*Validator)
79+
80+
// OnErr provides a callback that handles writing an HTTP response on a
81+
// validation error. This allows customization of error responses without
82+
// prescribing a particular form. This callback is only called on response
83+
// validator errors in Strict mode.
84+
func OnErr(f ErrFunc) ValidatorOption {
85+
return func(v *Validator) {
86+
v.errFunc = f
87+
}
88+
}
89+
90+
// OnLog provides a callback that handles logging in the Validator. This allows
91+
// the validator to integrate with a services' existing logging system without
92+
// prescribing a particular one.
93+
func OnLog(f LogFunc) ValidatorOption {
94+
return func(v *Validator) {
95+
v.logFunc = f
96+
}
97+
}
98+
99+
// Strict, if set, causes an internal server error to be sent if the wrapped
100+
// handler response fails response validation. If not set, the response is sent
101+
// and the error is only logged.
102+
func Strict(strict bool) ValidatorOption {
103+
return func(v *Validator) {
104+
v.strict = strict
105+
}
106+
}
107+
108+
// Middleware returns an http.Handler which wraps the given handler with
109+
// request and response validation.
110+
func (v *Validator) Middleware(h http.Handler) http.Handler {
111+
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
112+
route, pathParams, err := v.router.FindRoute(r)
113+
if err != nil {
114+
v.logFunc("validation error: failed to find route for "+r.URL.String(), err)
115+
v.errFunc(w, http.StatusNotFound, ErrCodeCannotFindRoute, err)
116+
return
117+
}
118+
requestValidationInput := &RequestValidationInput{
119+
Request: r,
120+
PathParams: pathParams,
121+
Route: route,
122+
}
123+
err = ValidateRequest(r.Context(), requestValidationInput)
124+
if err != nil {
125+
v.logFunc("invalid request", err)
126+
v.errFunc(w, http.StatusBadRequest, ErrCodeRequestInvalid, err)
127+
return
128+
}
129+
130+
var wr responseWrapper
131+
if v.strict {
132+
wr = &strictResponseWrapper{w: w}
133+
} else {
134+
wr = &warnResponseWrapper{w: w}
135+
}
136+
137+
h.ServeHTTP(wr, r)
138+
139+
err = ValidateResponse(r.Context(), &ResponseValidationInput{
140+
RequestValidationInput: requestValidationInput,
141+
Status: wr.statusCode(),
142+
Header: wr.Header(),
143+
Body: nopCloser{bytes.NewBuffer(wr.bodyContents())},
144+
})
145+
if err != nil {
146+
v.logFunc("invalid response", err)
147+
if v.strict {
148+
v.errFunc(w, http.StatusInternalServerError, ErrCodeResponseInvalid, err)
149+
}
150+
return
151+
}
152+
153+
err = wr.flushBodyContents()
154+
if err != nil {
155+
v.logFunc("failed to write response", err)
156+
}
157+
})
158+
}
159+
160+
type nopCloser struct {
161+
io.Reader
162+
}
163+
164+
// Close implements io.Closer.
165+
func (nopCloser) Close() error {
166+
return nil
167+
}
168+
169+
type responseWrapper interface {
170+
http.ResponseWriter
171+
172+
// flushBodyContents writes the buffered response to the client, if it has
173+
// not yet been written.
174+
flushBodyContents() error
175+
176+
// statusCode returns the response status code, 0 if not set yet.
177+
statusCode() int
178+
179+
// bodyContents returns the buffered
180+
bodyContents() []byte
181+
}
182+
183+
type warnResponseWrapper struct {
184+
w http.ResponseWriter
185+
fl http.Flusher
186+
status int
187+
body bytes.Buffer
188+
}
189+
190+
// Write implements http.ResponseWriter.
191+
func (l *warnResponseWrapper) Write(b []byte) (int, error) {
192+
if l.status == 0 {
193+
l.w.WriteHeader(l.status)
194+
}
195+
n, err := l.w.Write(b)
196+
if err == nil {
197+
l.body.Write(b)
198+
}
199+
return n, err
200+
}
201+
202+
// WriteHeader implements http.ResponseWriter.
203+
func (l *warnResponseWrapper) WriteHeader(status int) {
204+
if l.status == 0 {
205+
l.status = status
206+
}
207+
l.w.WriteHeader(l.status)
208+
}
209+
210+
// Header implements http.ResponseWriter.
211+
func (wr *warnResponseWrapper) Header() http.Header {
212+
return wr.w.Header()
213+
}
214+
215+
// Flush implements the optional http.Flusher interface.
216+
func (wr *warnResponseWrapper) Flush() {
217+
if fl, ok := wr.w.(http.Flusher); ok {
218+
fl.Flush()
219+
}
220+
}
221+
222+
func (l *warnResponseWrapper) flushBodyContents() error {
223+
return nil
224+
}
225+
226+
func (l *warnResponseWrapper) statusCode() int {
227+
return l.status
228+
}
229+
230+
func (l *warnResponseWrapper) bodyContents() []byte {
231+
return l.body.Bytes()
232+
}
233+
234+
type strictResponseWrapper struct {
235+
w http.ResponseWriter
236+
status int
237+
body bytes.Buffer
238+
}
239+
240+
// Write implements http.ResponseWriter.
241+
func (wr *strictResponseWrapper) Write(b []byte) (int, error) {
242+
if wr.status == 0 {
243+
wr.status = http.StatusOK
244+
}
245+
return wr.body.Write(b)
246+
}
247+
248+
// WriteHeader implements http.ResponseWriter.
249+
func (wr *strictResponseWrapper) WriteHeader(status int) {
250+
if wr.status == 0 {
251+
wr.status = status
252+
}
253+
}
254+
255+
// Header implements http.ResponseWriter.
256+
func (wr *strictResponseWrapper) Header() http.Header {
257+
return wr.w.Header()
258+
}
259+
260+
func (wr *strictResponseWrapper) flushBodyContents() error {
261+
wr.w.WriteHeader(wr.status)
262+
_, err := wr.w.Write(wr.body.Bytes())
263+
return err
264+
}
265+
266+
func (wr *strictResponseWrapper) statusCode() int {
267+
return wr.status
268+
}
269+
270+
func (wr *strictResponseWrapper) bodyContents() []byte {
271+
return wr.body.Bytes()
272+
}

0 commit comments

Comments
 (0)