diff --git a/graphql/handler/transport/http_multipart_mixed.go b/graphql/handler/transport/http_multipart_mixed.go index 362cae5102..bd0f904cf0 100644 --- a/graphql/handler/transport/http_multipart_mixed.go +++ b/graphql/handler/transport/http_multipart_mixed.go @@ -8,6 +8,8 @@ import ( "mime" "net/http" "strings" + "sync" + "time" "github.com/vektah/gqlparser/v2/gqlerror" @@ -16,7 +18,8 @@ import ( // MultipartMixed is a transport that supports the multipart/mixed spec type MultipartMixed struct { - Boundary string + Boundary string + DeliveryTimeout time.Duration } var _ graphql.Transport = MultipartMixed{} @@ -41,6 +44,8 @@ func (t MultipartMixed) Do(w http.ResponseWriter, r *http.Request, exec graphql. // 2022/08/23 as implemented by gqlgen. // * https://github.com/graphql/graphql-wg/blob/f22ea7748c6ebdf88fdbf770a8d9e41984ebd429/rfcs/DeferStream.md June 2023 Spec for the // `incremental` field + // * https://github.com/graphql/graphql-over-http/blob/main/rfcs/IncrementalDelivery.md + // multipart specification // Follows the format that is used in the Apollo Client tests: // https://github.com/apollographql/apollo-client/blob/v3.11.8/src/link/http/__tests__/responseIterator.ts#L68 // Apollo Client, despite mentioning in its requests that they require the 2022 spec, it wants the @@ -61,7 +66,12 @@ func (t MultipartMixed) Do(w http.ResponseWriter, r *http.Request, exec graphql. boundary := t.Boundary if boundary == "" { - boundary = "graphql" + boundary = "-" + } + timeout := t.DeliveryTimeout + if timeout.Milliseconds() < 1 { + // If the timeout is less than 1ms, we'll set it to 1ms to avoid a busy loop + timeout = 1 * time.Millisecond } params := &graphql.RawParams{} @@ -97,31 +107,37 @@ func (t MultipartMixed) Do(w http.ResponseWriter, r *http.Request, exec graphql. rc, opErr := exec.CreateOperationContext(ctx, params) ctx = graphql.WithOperationContext(ctx, rc) + if opErr != nil { + w.WriteHeader(statusFor(opErr)) + + resp := exec.DispatchError(ctx, opErr) + writeJson(w, resp) + return + } - // Example of the response format (note the new lines are important!): + // Example of the response format (note the new lines and boundaries are important!): + // https://github.com/graphql/graphql-over-http/blob/main/rfcs/IncrementalDelivery.md // --graphql // Content-Type: application/json // // {"data":{"apps":{"apps":[ .. ],"totalNumApps":161,"__typename":"AppsOutput"}},"hasNext":true} - // // --graphql // Content-Type: application/json // // {"incremental":[{"data":{"groupAccessCount":0},"label":"test","path":["apps","apps",7],"hasNext":true}],"hasNext":true} - - if opErr != nil { - w.WriteHeader(statusFor(opErr)) - - resp := exec.DispatchError(ctx, opErr) - writeJson(w, resp) - return - } + // --graphql + // ... + // --graphql-- + // Last boundary is a closing boundary with two dashes at the end. w.Header().Set( "Content-Type", fmt.Sprintf(`multipart/mixed;boundary="%s";deferSpec=20220824`, boundary), ) + a := newMultipartResponseAggregator(w, boundary, timeout) + defer a.Done(w) + responses, ctx := exec.DispatchOperation(ctx, rc) initialResponse := true for { @@ -130,27 +146,18 @@ func (t MultipartMixed) Do(w http.ResponseWriter, r *http.Request, exec graphql. break } - fmt.Fprintf(w, "--%s\r\n", boundary) - fmt.Fprintf(w, "Content-Type: application/json\r\n\r\n") - - if initialResponse { - writeJson(w, response) - initialResponse = false - } else { - writeIncrementalJson(w, response, response.HasNext) - } - fmt.Fprintf(w, "\r\n\r\n") - flusher.Flush() + a.Add(response, initialResponse) + initialResponse = false } } -func writeIncrementalJson(w io.Writer, response *graphql.Response, hasNext *bool) { +func writeIncrementalJson(w io.Writer, responses []*graphql.Response, hasNext bool) { // TODO: Remove this wrapper on response once gqlgen supports the 2023 spec b, err := json.Marshal(struct { - Incremental []graphql.Response `json:"incremental"` - HasNext *bool `json:"hasNext"` + Incremental []*graphql.Response `json:"incremental"` + HasNext bool `json:"hasNext"` }{ - Incremental: []graphql.Response{*response}, + Incremental: responses, HasNext: hasNext, }) if err != nil { @@ -158,3 +165,115 @@ func writeIncrementalJson(w io.Writer, response *graphql.Response, hasNext *bool } w.Write(b) } + +func writeBoundary(w io.Writer, boundary string, finalResponse bool) { + if finalResponse { + fmt.Fprintf(w, "--%s--\r\n", boundary) + return + } + fmt.Fprintf(w, "--%s\r\n", boundary) +} + +func writeContentTypeHeader(w io.Writer) { + fmt.Fprintf(w, "Content-Type: application/json\r\n\r\n") +} + +// multipartResponseAggregator helps us reduce the number of responses sent to the frontend by batching all the +// incremental responses together. +type multipartResponseAggregator struct { + mu sync.Mutex + boundary string + initialResponse *graphql.Response + deferResponses []*graphql.Response + done chan bool +} + +// newMultipartResponseAggregator creates a new multipartResponseAggregator +// The aggregator will flush responses to the client every `tickerDuration` (default 1ms) so that +// multiple incremental responses are batched together. +func newMultipartResponseAggregator( + w http.ResponseWriter, + boundary string, + tickerDuration time.Duration, +) *multipartResponseAggregator { + a := &multipartResponseAggregator{ + boundary: boundary, + done: make(chan bool, 1), + } + go func() { + ticker := time.NewTicker(tickerDuration) + defer ticker.Stop() + for { + select { + case <-a.done: + return + case <-ticker.C: + a.flush(w) + } + } + }() + return a +} + +// Done flushes the remaining responses +func (a *multipartResponseAggregator) Done(w http.ResponseWriter) { + a.done <- true + a.flush(w) +} + +// Add accumulates the responses +func (a *multipartResponseAggregator) Add(resp *graphql.Response, initialResponse bool) { + a.mu.Lock() + defer a.mu.Unlock() + if initialResponse { + a.initialResponse = resp + return + } + a.deferResponses = append(a.deferResponses, resp) +} + +// flush sends the accumulated responses to the client +func (a *multipartResponseAggregator) flush(w http.ResponseWriter) { + a.mu.Lock() + defer a.mu.Unlock() + + // If we don't have any responses, we can return early + if a.initialResponse == nil && len(a.deferResponses) == 0 { + return + } + + flusher, ok := w.(http.Flusher) + if !ok { + // This should never happen, as we check for this much earlier on + panic("response writer does not support flushing") + } + + hasNext := false + if a.initialResponse != nil { + // Initial response will need to begin with the boundary + writeBoundary(w, a.boundary, false) + writeContentTypeHeader(w) + + writeJson(w, a.initialResponse) + hasNext = a.initialResponse.HasNext != nil && *a.initialResponse.HasNext + // Reset the initial response so we don't send it again + a.initialResponse = nil + } + + if len(a.deferResponses) > 0 { + writeContentTypeHeader(w) + hasNext = a.deferResponses[len(a.deferResponses)-1].HasNext != nil && + *a.deferResponses[len(a.deferResponses)-1].HasNext + writeIncrementalJson(w, a.deferResponses, hasNext) + // Reset the deferResponses so we don't send them again + a.deferResponses = nil + } + + // Make sure to put the delimiter after every request, so that Apollo Client knows that the + // current payload has been sent, and updates the UI. This is particular important for the first + // response and the last response, which may either hang or never get handled. + // Final response will have a closing boundary with two dashes at the end. + fmt.Fprintf(w, "\r\n") + writeBoundary(w, a.boundary, !hasNext) + flusher.Flush() +} diff --git a/graphql/handler/transport/http_multipart_mixed_test.go b/graphql/handler/transport/http_multipart_mixed_test.go index e590074e77..090a3e96bf 100644 --- a/graphql/handler/transport/http_multipart_mixed_test.go +++ b/graphql/handler/transport/http_multipart_mixed_test.go @@ -19,7 +19,9 @@ import ( func TestMultipartMixed(t *testing.T) { initialize := func() *testserver.TestServer { h := testserver.New() - h.AddTransport(transport.MultipartMixed{}) + h.AddTransport(transport.MultipartMixed{ + Boundary: "graphql", + }) return h } @@ -136,7 +138,6 @@ func TestMultipartMixed(t *testing.T) { "{\"data\":{\"name\":null},\"hasNext\":true}\r\n", readLine(br), ) - assert.Equal(t, "\r\n", readLine(br)) wg.Add(1) go func() { @@ -152,7 +153,8 @@ func TestMultipartMixed(t *testing.T) { "{\"incremental\":[{\"data\":{\"name\":\"test\"},\"hasNext\":false}],\"hasNext\":false}\r\n", readLine(br), ) - assert.Equal(t, "\r\n", readLine(br)) + + assert.Equal(t, "--graphql--\r\n", readLine(br)) wg.Add(1) go func() {