Skip to content

Commit 5105f51

Browse files
feat!: refactor ws response writer to add validation logic (#33)
* refactor!: ws metrics response writer * chore: remove unused interface. * feat: add validate graphql request method * chore: change varname * chore: update .golangci.yml * chore: returning error should be `io.EOF`
1 parent 1a8cfed commit 5105f51

File tree

7 files changed

+182
-99
lines changed

7 files changed

+182
-99
lines changed

.air.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ tmp_dir = "tmp"
66
bin = "./tmp/gbox run -config ./Caddyfile.dist -watch"
77
cmd = "go build -o ./tmp/gbox ./cmd"
88
delay = 1000
9-
exclude_dir = ["assets", "tmp", "vendor", "testdata", "internal/testserver"]
9+
exclude_dir = ["assets", "tmp", "vendor", "testdata", "internal/testserver", "charts", "dist"]
1010
exclude_file = []
1111
exclude_regex = ["_test.go"]
1212
exclude_unchanged = false

.golangci.yml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,9 @@ linters:
3232
- wrapcheck
3333
- goerr113
3434
- gochecknoglobals
35+
- execinquery
36+
- exhaustruct
37+
- nonamedreturns
3538

3639
# deprecated
3740
- interfacer

handler_test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -110,7 +110,7 @@ func (s *HandlerIntegrationTestSuite) TestIntrospection() {
110110
"disabled": {
111111
extraConfig: "disabled_introspection true",
112112
payload: `{"query": "query { __schema { queryType { name } } }"}`,
113-
expectedBody: `{"errors":[{"message":"introspection queries are not allowed"}]}`,
113+
expectedBody: `{"errors":[{"message":"introspection query is not allowed"}]}`,
114114
},
115115
}
116116

metrics.go

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -56,11 +56,6 @@ type Metrics struct {
5656
cachingCount *prometheus.CounterVec
5757
}
5858

59-
type requestMetrics interface {
60-
addMetricsBeginRequest(*graphql.Request)
61-
addMetricsEndRequest(*graphql.Request, time.Duration)
62-
}
63-
6459
type cachingMetrics interface {
6560
addMetricsCaching(*graphql.Request, CachingStatus)
6661
}

router.go

Lines changed: 28 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,8 @@ const (
2525
graphQLPath = "/graphql"
2626
)
2727

28+
var ErrNotAllowIntrospectionQuery = errors.New("introspection query is not allowed")
29+
2830
func (h *Handler) initRouter() {
2931
router := mux.NewRouter()
3032
router.Path(graphQLPath).HeadersRegexp(
@@ -75,8 +77,8 @@ func (h *Handler) GraphQLOverWebsocketHandle(w http.ResponseWriter, r *http.Requ
7577
}
7678

7779
n := r.Context().Value(nextHandlerCtxKey).(caddyhttp.Handler)
78-
mr := newWebsocketMetricsResponseWriter(w, h.schema, h)
79-
reporter.error = h.ReverseProxy.ServeHTTP(mr, r, n)
80+
wsr := newWebsocketResponseWriter(w, h)
81+
reporter.error = h.ReverseProxy.ServeHTTP(wsr, r, n)
8082
}
8183

8284
// GraphQLHandle ensure GraphQL request is safe before forwarding to upstream and caching query result of it.
@@ -91,36 +93,23 @@ func (h *Handler) GraphQLHandle(w http.ResponseWriter, r *http.Request) {
9193

9294
gqlRequest, err := h.unmarshalHTTPRequest(r)
9395
if err != nil {
94-
h.logger.Debug("unmarshal GQL cachingRequest from http cachingRequest failure", zap.Error(err))
96+
h.logger.Debug("can not unmarshal graphql request from http request", zap.Error(err))
9597
reporter.error = writeResponseErrors(err, w)
9698

9799
return
98100
}
99101

100-
isIntrospectQuery, _ := gqlRequest.IsIntrospectionQuery()
101-
start := time.Now()
102-
103102
h.addMetricsBeginRequest(gqlRequest)
104-
defer func() {
105-
h.addMetricsEndRequest(gqlRequest, time.Since(start))
106-
}()
103+
defer func(startedAt time.Time) {
104+
h.addMetricsEndRequest(gqlRequest, time.Since(startedAt))
105+
}(time.Now())
107106

108-
if isIntrospectQuery && h.DisabledIntrospection {
109-
reporter.error = writeResponseErrors(errors.New("introspection queries are not allowed"), w)
107+
if err = h.validateGraphqlRequest(gqlRequest); err != nil {
108+
reporter.error = writeResponseErrors(err, w)
110109

111110
return
112111
}
113112

114-
if h.Complexity != nil {
115-
requestErrors := h.Complexity.validateRequest(h.schema, gqlRequest)
116-
117-
if requestErrors.Count() > 0 {
118-
reporter.error = writeResponseErrors(requestErrors, w)
119-
120-
return
121-
}
122-
}
123-
124113
n := r.Context().Value(nextHandlerCtxKey).(caddyhttp.Handler)
125114

126115
if h.Caching != nil {
@@ -161,6 +150,24 @@ func (h *Handler) unmarshalHTTPRequest(r *http.Request) (*graphql.Request, error
161150
return gqlRequest, nil
162151
}
163152

153+
func (h *Handler) validateGraphqlRequest(r *graphql.Request) error {
154+
isIntrospectQuery, _ := r.IsIntrospectionQuery()
155+
156+
if isIntrospectQuery && h.DisabledIntrospection {
157+
return ErrNotAllowIntrospectionQuery
158+
}
159+
160+
if h.Complexity != nil {
161+
requestErrors := h.Complexity.validateRequest(h.schema, r)
162+
163+
if requestErrors.Count() > 0 {
164+
return requestErrors
165+
}
166+
}
167+
168+
return nil
169+
}
170+
164171
// AdminGraphQLHandle purging query result cached and describe cache key.
165172
func (h *Handler) AdminGraphQLHandle(w http.ResponseWriter, r *http.Request) {
166173
resolver := admin.NewResolver(h.schema, h.schemaDocument, h.logger, h.Caching)

ws.go

Lines changed: 88 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ import (
44
"bufio"
55
"bytes"
66
"encoding/json"
7+
"io"
78
"net"
89
"net/http"
910
"time"
@@ -13,54 +14,79 @@ import (
1314
"github.com/jensneuse/graphql-go-tools/pkg/graphql"
1415
)
1516

16-
type wsMetricsResponseWriter struct {
17-
requestMetrics
17+
type wsSubscriber interface {
18+
onWsSubscribe(*graphql.Request) error
19+
onWsClose(*graphql.Request, time.Duration)
20+
}
21+
22+
func (h *Handler) onWsSubscribe(r *graphql.Request) (err error) {
23+
if err = normalizeGraphqlRequest(h.schema, r); err != nil {
24+
return err
25+
}
26+
27+
if err = h.validateGraphqlRequest(r); err != nil {
28+
return err
29+
}
30+
31+
h.addMetricsBeginRequest(r)
32+
33+
return nil
34+
}
35+
36+
func (h *Handler) onWsClose(r *graphql.Request, d time.Duration) {
37+
h.addMetricsEndRequest(r, d)
38+
}
39+
40+
type wsResponseWriter struct {
1841
*caddyhttp.ResponseWriterWrapper
19-
schema *graphql.Schema
42+
subscriber wsSubscriber
2043
}
2144

22-
func newWebsocketMetricsResponseWriter(w http.ResponseWriter, s *graphql.Schema, rm requestMetrics) *wsMetricsResponseWriter {
23-
return &wsMetricsResponseWriter{
45+
func newWebsocketResponseWriter(w http.ResponseWriter, s wsSubscriber) *wsResponseWriter {
46+
return &wsResponseWriter{
2447
ResponseWriterWrapper: &caddyhttp.ResponseWriterWrapper{
2548
ResponseWriter: w,
2649
},
27-
schema: s,
28-
requestMetrics: rm,
50+
subscriber: s,
2951
}
3052
}
3153

32-
// Hijack connection for collecting subscription metrics.
33-
func (r *wsMetricsResponseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
54+
// Hijack connection for validating, and collecting metrics.
55+
func (r *wsResponseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
3456
c, w, e := r.ResponseWriterWrapper.Hijack()
3557

3658
if c != nil {
37-
c = &wsMetricsConn{
38-
Conn: c,
39-
requestMetrics: r.requestMetrics,
40-
schema: r.schema,
59+
c = &wsConn{
60+
Conn: c,
61+
wsSubscriber: r.subscriber,
4162
}
4263
}
4364

4465
return c, w, e
4566
}
4667

47-
type wsMetricsConn struct {
68+
type wsConn struct {
4869
net.Conn
49-
requestMetrics
70+
wsSubscriber
5071
request *graphql.Request
51-
schema *graphql.Schema
5272
subscribeAt time.Time
5373
}
5474

55-
func (c *wsMetricsConn) Read(b []byte) (n int, err error) {
75+
type wsMessage struct {
76+
ID interface{} `json:"id"`
77+
Type string `json:"type"`
78+
Payload json.RawMessage `json:"payload,omitempty"`
79+
}
80+
81+
func (c *wsConn) Read(b []byte) (n int, err error) {
5682
n, err = c.Conn.Read(b)
5783

58-
if c.request != nil || err != nil {
59-
if err != nil {
60-
c.addMetricsEndRequest(c.request, time.Since(c.subscribeAt))
61-
c.request = nil
62-
}
84+
if c.request != nil && err != nil {
85+
c.onWsClose(c.request, time.Since(c.subscribeAt))
86+
c.request = nil
87+
}
6388

89+
if c.request != nil || err != nil {
6490
return n, err
6591
}
6692

@@ -76,10 +102,7 @@ func (c *wsMetricsConn) Read(b []byte) (n int, err error) {
76102
}
77103

78104
decoder := json.NewDecoder(r)
79-
msg := &struct {
80-
Type string `json:"type"`
81-
Payload json.RawMessage `json:"payload"`
82-
}{}
105+
msg := &wsMessage{}
83106

84107
if e := decoder.Decode(msg); e != nil {
85108
return n, err
@@ -92,14 +115,50 @@ func (c *wsMetricsConn) Read(b []byte) (n int, err error) {
92115
return n, err
93116
}
94117

95-
if e := normalizeGraphqlRequest(c.schema, request); e != nil {
96-
return n, err
118+
if e := c.onWsSubscribe(request); e != nil {
119+
c.writeErrorMessage(msg.ID, e)
120+
c.writeCompleteMessage(msg.ID)
121+
122+
return n, io.EOF
97123
}
98124

99125
c.request = request
100126
c.subscribeAt = time.Now()
101-
c.addMetricsBeginRequest(request)
102127
}
103128

104129
return n, err
105130
}
131+
132+
func (c *wsConn) writeErrorMessage(id interface{}, errMsg error) error {
133+
errMsgRaw, errMsgErr := json.Marshal(graphql.RequestErrorsFromError(errMsg))
134+
135+
if errMsgErr != nil {
136+
return errMsgErr
137+
}
138+
139+
msg := &wsMessage{
140+
ID: id,
141+
Type: "error",
142+
Payload: json.RawMessage(errMsgRaw),
143+
}
144+
145+
payload, err := json.Marshal(msg)
146+
if err != nil {
147+
return err
148+
}
149+
150+
return wsutil.WriteServerText(c, payload)
151+
}
152+
153+
func (c *wsConn) writeCompleteMessage(id interface{}) error {
154+
msg := &wsMessage{
155+
ID: id,
156+
Type: "complete",
157+
}
158+
payload, err := json.Marshal(msg)
159+
if err != nil {
160+
return err
161+
}
162+
163+
return wsutil.WriteServerText(c, payload)
164+
}

0 commit comments

Comments
 (0)