Skip to content

Commit 2d7ab94

Browse files
authored
Ensure queries are cancelled correctly via the frontend. (#1508)
* Ensure we cancel queries that traverse the frontend correctly. And test it works. Signed-off-by: Tom Wilkie <tom.wilkie@gmail.com> * Lint: spelling. Signed-off-by: Tom Wilkie <tom.wilkie@gmail.com>
1 parent bc34ecd commit 2d7ab94

File tree

3 files changed

+104
-82
lines changed

3 files changed

+104
-82
lines changed

pkg/querier/frontend/frontend.go

Lines changed: 36 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -210,7 +210,10 @@ func (f *Frontend) RoundTripGRPC(ctx context.Context, req *ProcessRequest) (*Pro
210210
request := &request{
211211
request: req,
212212
originalCtx: ctx,
213-
// Buffer of 1 to ensure response can be written even if client has gone away.
213+
214+
// Buffer of 1 to ensure response can be written by the server side
215+
// of the Process stream, even if this goroutine goes away due to
216+
// client context cancellation.
214217
err: make(chan error, 1),
215218
response: make(chan *ProcessResponse, 1),
216219
}
@@ -262,80 +265,56 @@ func (f *Frontend) RoundTripGRPC(ctx context.Context, req *ProcessRequest) (*Pro
262265

263266
// Process allows backends to pull requests from the frontend.
264267
func (f *Frontend) Process(server Frontend_ProcessServer) error {
265-
var (
266-
sendChan = make(chan *ProcessRequest)
267-
recvChan = make(chan *ProcessResponse, 1)
268-
269-
// Need buffer of 2 so goroutines reading/writing to stream don't hang
270-
// around when stream dies.
271-
errChan = make(chan error, 2)
272-
)
273-
274-
// If the stream from the querier is canceled, ping the condition to unblock.
275-
// This is done once, here (instead of in getNextRequest) as we expect calls
276-
// to Process to process many requests.
268+
// If the downstream request(from querier -> frontend) is cancelled,
269+
// we need to ping the condition variable to unblock getNextRequest.
270+
// Ideally we'd have ctx aware condition variables...
277271
go func() {
278272
<-server.Context().Done()
279273
f.cond.Broadcast()
280274
}()
281275

282-
// Use a pair of goroutines to read/write from the stream and send to channels,
283-
// so we can use selects to also wait on the cancellation of the request context.
284-
// These goroutines will error out when the stream returns.
285-
go func() {
286-
for {
287-
var req *ProcessRequest
288-
select {
289-
case req = <-sendChan:
290-
case <-server.Context().Done():
291-
return
292-
}
276+
for {
277+
req, err := f.getNextRequest(server.Context())
278+
if err != nil {
279+
return err
280+
}
293281

294-
err := server.Send(req)
282+
// Handle the stream sending & receiving on a goroutine so we can
283+
// monitor the contexts in a select and cancel things appropriately.
284+
resps := make(chan *ProcessResponse, 1)
285+
errs := make(chan error, 1)
286+
go func() {
287+
err = server.Send(req.request)
295288
if err != nil {
296-
errChan <- err
289+
errs <- err
297290
return
298291
}
299-
}
300-
}()
301292

302-
go func() {
303-
for {
304293
resp, err := server.Recv()
305-
if err == nil {
306-
recvChan <- resp
307-
} else {
308-
errChan <- err
294+
if err != nil {
295+
errs <- err
309296
return
310297
}
311-
}
312-
}()
313298

314-
for {
315-
request, err := f.getNextRequest(server.Context())
316-
if err != nil {
317-
return err
318-
}
319-
320-
originalCtx := request.originalCtx
299+
resps <- resp
300+
}()
321301

322302
select {
323-
case sendChan <- request.request:
324-
case err := <-errChan:
325-
request.err <- err
303+
// If the upstream request is cancelled, we need to cancel the
304+
// downstream req. Only way we can do that is to close the stream.
305+
// The worker client is expecting these semantics.
306+
case <-req.originalCtx.Done():
307+
return req.originalCtx.Err()
308+
309+
// If there was an error handling this request due to network IO,
309+
// then error out this upstream request _and_ stream.
311+
case err := <-errs:
312+
req.err <- err
326313
return err
327-
case <-originalCtx.Done():
328-
return originalCtx.Err()
329-
}
330314

331-
select {
332-
case resp := <-recvChan:
333-
request.response <- resp
334-
case err := <-errChan:
335-
request.err <- err
336-
return err
337-
case <-originalCtx.Done():
338-
return originalCtx.Err()
315+
// Happy path: propagate the response.
316+
case resp := <-resps:
317+
req.response <- resp
339318
}
340319
}
341320
}

pkg/querier/frontend/frontend_test.go

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ import (
88
"net/http"
99
"sync/atomic"
1010
"testing"
11+
"time"
1112

1213
"github.com/cortexproject/cortex/pkg/util/flagext"
1314
"github.com/go-kit/kit/log"
@@ -124,6 +125,37 @@ func TestFrontendPropagateTrace(t *testing.T) {
124125
testFrontend(t, handler, test)
125126
}
126127

128+
// TestFrontendCancel ensures that when client requests are cancelled,
129+
// the underlying query is correctly cancelled _and not retried_.
130+
func TestFrontendCancel(t *testing.T) {
131+
var tries int32
132+
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
133+
<-r.Context().Done()
134+
atomic.AddInt32(&tries, 1)
135+
})
136+
test := func(addr string) {
137+
req, err := http.NewRequest("GET", fmt.Sprintf("http://%s/", addr), nil)
138+
require.NoError(t, err)
139+
err = user.InjectOrgIDIntoHTTPRequest(user.InjectOrgID(context.Background(), "1"), req)
140+
require.NoError(t, err)
141+
142+
ctx, cancel := context.WithCancel(context.Background())
143+
req = req.WithContext(ctx)
144+
145+
go func() {
146+
time.Sleep(100 * time.Millisecond)
147+
cancel()
148+
}()
149+
150+
_, err = http.DefaultClient.Do(req)
151+
require.Error(t, err)
152+
153+
time.Sleep(100 * time.Millisecond)
154+
assert.Equal(t, int32(1), atomic.LoadInt32(&tries))
155+
}
156+
testFrontend(t, handler, test)
157+
}
158+
127159
func testFrontend(t *testing.T, handler http.Handler, test func(addr string)) {
128160
logger := log.NewNopLogger() //log.NewLogfmtLogger(log.NewSyncWriter(os.Stderr))
129161

@@ -133,6 +165,7 @@ func testFrontend(t *testing.T, handler http.Handler, test func(addr string)) {
133165
)
134166
flagext.DefaultValues(&config, &workerConfig)
135167
config.SplitQueriesByDay = true
168+
workerConfig.Parallelism = 1
136169

137170
// localhost:0 prevents firewall warnings on Mac OS X.
138171
grpcListen, err := net.Listen("tcp", "localhost:0")

pkg/querier/frontend/worker.go

Lines changed: 35 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -173,7 +173,7 @@ func (w *worker) runOne(ctx context.Context, client FrontendClient) {
173173
continue
174174
}
175175

176-
if err := w.process(ctx, c); err != nil {
176+
if err := w.process(c); err != nil {
177177
level.Error(w.log).Log("msg", "error processing requests", "err", err)
178178
backoff.Wait()
179179
continue
@@ -184,41 +184,51 @@ func (w *worker) runOne(ctx context.Context, client FrontendClient) {
184184
}
185185

186186
// process loops processing requests on an established stream.
187-
func (w *worker) process(ctx context.Context, c Frontend_ProcessClient) error {
187+
func (w *worker) process(c Frontend_ProcessClient) error {
188+
// Build a child context so we can cancel queries when the stream is closed.
189+
ctx, cancel := context.WithCancel(c.Context())
190+
defer cancel()
191+
188192
for {
189193
request, err := c.Recv()
190194
if err != nil {
191195
return err
192196
}
193197

194-
response, err := w.server.Handle(ctx, request.HttpRequest)
195-
if err != nil {
196-
var ok bool
197-
response, ok = httpgrpc.HTTPResponseFromError(err)
198-
if !ok {
199-
response = &httpgrpc.HTTPResponse{
200-
Code: http.StatusInternalServerError,
201-
Body: []byte(err.Error()),
198+
// Handle the request on a "background" goroutine, so we go back to
199+
// blocking on c.Recv(). This allows us to detect the stream closing
200+
// and cancel the query. We don't actually handle queries in parallel
201+
// here, as we're running in lock step with the server - each Recv is
202+
// paired with a Send.
203+
go func() {
204+
response, err := w.server.Handle(ctx, request.HttpRequest)
205+
if err != nil {
206+
var ok bool
207+
response, ok = httpgrpc.HTTPResponseFromError(err)
208+
if !ok {
209+
response = &httpgrpc.HTTPResponse{
210+
Code: http.StatusInternalServerError,
211+
Body: []byte(err.Error()),
212+
}
202213
}
203214
}
204-
}
205-
206-
if len(response.Body) >= w.cfg.GRPCClientConfig.MaxSendMsgSize {
207-
errMsg := fmt.Sprintf("the response is larger than the max (%d vs %d)", len(response.Body), w.cfg.GRPCClientConfig.MaxSendMsgSize)
208215

209-
// This makes sure the request is not retried, else a 500 is sent and we retry the large query again.
210-
response = &httpgrpc.HTTPResponse{
211-
Code: http.StatusRequestEntityTooLarge,
212-
Body: []byte(errMsg),
216+
// Ensure responses that are too big are not retried.
217+
if len(response.Body) >= w.cfg.GRPCClientConfig.MaxSendMsgSize {
218+
errMsg := fmt.Sprintf("response larger than the max (%d vs %d)", len(response.Body), w.cfg.GRPCClientConfig.MaxSendMsgSize)
219+
response = &httpgrpc.HTTPResponse{
220+
Code: http.StatusRequestEntityTooLarge,
221+
Body: []byte(errMsg),
222+
}
223+
level.Error(w.log).Log("msg", "error processing query", "err", errMsg)
213224
}
214-
level.Error(w.log).Log("msg", "error processing query", "err", errMsg)
215-
}
216225

217-
if err := c.Send(&ProcessResponse{
218-
HttpResponse: response,
219-
}); err != nil {
220-
return err
221-
}
226+
if err := c.Send(&ProcessResponse{
227+
HttpResponse: response,
228+
}); err != nil {
229+
level.Error(w.log).Log("msg", "error processing requests", "err", err)
230+
}
231+
}()
222232
}
223233
}
224234

0 commit comments

Comments
 (0)