Skip to content

Commit 182c73b

Browse files
cyriltovena, simonswine, and pracucci
authored
Correctly apply parallelism limits when doing querysharding. (#253)
* Correctly apply parallelism limits when doing querysharding. Previously we were only splitting by day/time in the frontend, so applying the max parallelism was easy. Now we can also apply parallelism at the querysharding level, and this means we can easily bypass the `MaxQueryParallelism` limit. This PR applies the limit at a lower level and so fixes the per-query overscheduling problem we currently have since querysharding became activatable. This is inspired by work we've done in Loki. Signed-off-by: Cyril Tovena <cyril.tovena@gmail.com> * Update pkg/querier/queryrange/limits.go Co-authored-by: Christian Simon <simon@swine.de> * Update pkg/querier/queryrange/limits_test.go Co-authored-by: Christian Simon <simon@swine.de> * Checks for error in tests. Signed-off-by: Cyril Tovena <cyril.tovena@gmail.com> * Improve concurrency handling. Signed-off-by: Cyril Tovena <cyril.tovena@gmail.com> * Added unit test on context cancellation Signed-off-by: Marco Pracucci <marco@pracucci.com> * Simplify the code. Signed-off-by: Cyril Tovena <cyril.tovena@gmail.com> Co-authored-by: Christian Simon <simon@swine.de> Co-authored-by: Marco Pracucci <marco@pracucci.com>
1 parent c5e794b commit 182c73b

File tree

3 files changed

+288
-13
lines changed

3 files changed

+288
-13
lines changed

pkg/querier/queryrange/limits.go

Lines changed: 115 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,9 +8,11 @@ package queryrange
88
import (
99
"context"
1010
"net/http"
11+
"sync"
1112
"time"
1213

1314
"github.com/go-kit/kit/log/level"
15+
"github.com/opentracing/opentracing-go"
1416
"github.com/prometheus/prometheus/pkg/timestamp"
1517
"github.com/weaveworks/common/httpgrpc"
1618

@@ -103,3 +105,116 @@ func (l limitsMiddleware) Do(ctx context.Context, r Request) (Response, error) {
103105

104106
return l.next.Do(ctx, r)
105107
}
108+
109+
// limitedRoundTripper is an http.RoundTripper that enforces the per-tenant
// MaxQueryParallelism limit on the sub-requests produced by the wrapped
// middlewares before they reach the downstream handler.
type limitedRoundTripper struct {
	// downstream executes each sub-request against the next roundtripper.
	downstream Handler
	// limits provides the per-tenant MaxQueryParallelism setting.
	limits Limits

	codec      Codec
	middleware Middleware
}
116+
117+
// NewLimitedRoundTripper creates a new roundtripper that enforces MaxQueryParallelism to the `next` roundtripper across `middlewares`.
118+
func NewLimitedRoundTripper(next http.RoundTripper, codec Codec, limits Limits, middlewares ...Middleware) http.RoundTripper {
119+
transport := limitedRoundTripper{
120+
downstream: roundTripperHandler{
121+
next: next,
122+
codec: codec,
123+
},
124+
codec: codec,
125+
limits: limits,
126+
middleware: MergeMiddlewares(middlewares...),
127+
}
128+
return transport
129+
}
130+
131+
// subRequest carries a single sub-request from the middleware chain to a
// worker goroutine, together with the channel on which the worker reports
// the outcome.
type subRequest struct {
	req Request
	ctx context.Context
	// result has capacity 1 (see newSubRequest), so the worker's send never
	// blocks even if the requester has already stopped waiting.
	result chan result
}
136+
137+
// result is the outcome of a single sub-request executed by a worker.
type result struct {
	response Response
	err      error
}
141+
142+
func newSubRequest(ctx context.Context, req Request) subRequest {
143+
return subRequest{
144+
req: req,
145+
ctx: ctx,
146+
result: make(chan result, 1),
147+
}
148+
}
149+
150+
// RoundTrip decodes the incoming HTTP request, runs it through the
// middleware chain, and encodes the final response. Sub-requests generated
// by the middlewares are executed by a fixed pool of worker goroutines, so
// at most MaxQueryParallelism (per tenant) downstream calls are in flight
// for this query at any given time.
func (rt limitedRoundTripper) RoundTrip(r *http.Request) (*http.Response, error) {
	var (
		wg           sync.WaitGroup
		intermediate = make(chan subRequest)
		ctx, cancel  = context.WithCancel(r.Context())
	)
	defer func() {
		// Stop all workers and wait for them to exit before returning, so
		// no goroutine outlives this request.
		cancel()
		wg.Wait()
	}()

	request, err := rt.codec.DecodeRequest(ctx, r)
	if err != nil {
		return nil, err
	}

	if span := opentracing.SpanFromContext(ctx); span != nil {
		request.LogToSpan(span)
	}

	tenantIDs, err := tenant.TenantIDs(ctx)
	if err != nil {
		return nil, httpgrpc.Errorf(http.StatusBadRequest, err.Error())
	}

	// Creates workers that will process the sub-requests in parallel for this query.
	// The amount of workers is limited by the MaxQueryParallelism tenant setting.
	parallelism := validation.SmallestPositiveIntPerTenant(tenantIDs, rt.limits.MaxQueryParallelism)
	for i := 0; i < parallelism; i++ {
		wg.Add(1)
		go func() {
			defer wg.Done()
			for {
				select {
				case w := <-intermediate:
					resp, err := rt.downstream.Do(w.ctx, w.req)
					// w.result is buffered (capacity 1), so this send never
					// blocks even if the requester has stopped listening.
					w.result <- result{response: resp, err: err}
				case <-ctx.Done():
					return
				}
			}
		}()
	}

	// Wraps middlewares with a final handler, which will receive requests in
	// parallel from upstream handlers. Then each requests gets scheduled to a
	// different worker via the `intermediate` channel, so the maximum
	// parallelism is limited. This worker will then call `Do` on the resulting
	// handler.
	response, err := rt.middleware.Wrap(
		HandlerFunc(func(ctx context.Context, r Request) (Response, error) {
			s := newSubRequest(ctx, r)
			// Queue the sub-request for a worker, bailing out if the
			// sub-request's context is cancelled before one is free.
			select {
			case intermediate <- s:
			case <-ctx.Done():
				return nil, ctx.Err()
			}

			// Wait for the worker's result, again honouring cancellation.
			select {
			case response := <-s.result:
				return response.response, response.err
			case <-ctx.Done():
				return nil, ctx.Err()
			}
		})).Do(ctx, request)
	if err != nil {
		return nil, err
	}

	return rt.codec.EncodeResponse(ctx, response)
}

pkg/querier/queryrange/limits_test.go

Lines changed: 159 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -7,13 +7,16 @@ package queryrange
77

88
import (
99
"context"
10+
"net/http"
11+
"sync"
1012
"testing"
1113
"time"
1214

1315
"github.com/stretchr/testify/assert"
1416
"github.com/stretchr/testify/mock"
1517
"github.com/stretchr/testify/require"
1618
"github.com/weaveworks/common/user"
19+
"go.uber.org/atomic"
1720

1821
"github.com/grafana/mimir/pkg/util"
1922
)
@@ -192,10 +195,11 @@ func TestLimitsMiddleware_MaxQueryLength(t *testing.T) {
192195
}
193196

194197
type mockLimits struct {
195-
maxQueryLookback time.Duration
196-
maxQueryLength time.Duration
197-
maxCacheFreshness time.Duration
198-
totalShards int
198+
maxQueryLookback time.Duration
199+
maxQueryLength time.Duration
200+
maxCacheFreshness time.Duration
201+
maxQueryParallelism int
202+
totalShards int
199203
}
200204

201205
func (m mockLimits) MaxQueryLookback(string) time.Duration {
@@ -206,8 +210,11 @@ func (m mockLimits) MaxQueryLength(string) time.Duration {
206210
return m.maxQueryLength
207211
}
208212

209-
func (mockLimits) MaxQueryParallelism(string) int {
210-
return 14 // Flag default.
213+
func (m mockLimits) MaxQueryParallelism(string) int {
214+
if m.maxQueryParallelism == 0 {
215+
return 14 // Flag default.
216+
}
217+
return m.maxQueryParallelism
211218
}
212219

213220
func (m mockLimits) MaxCacheFreshness(string) time.Duration {
@@ -226,3 +233,149 @@ func (m *mockHandler) Do(ctx context.Context, req Request) (Response, error) {
226233
args := m.Called(ctx, req)
227234
return args.Get(0).(Response), args.Error(1)
228235
}
236+
237+
func TestLimitedRoundTripper_MaxQueryParallelism(t *testing.T) {
238+
var (
239+
maxQueryParallelism = 2
240+
count atomic.Int32
241+
max atomic.Int32
242+
downstream = RoundTripFunc(func(_ *http.Request) (*http.Response, error) {
243+
cur := count.Inc()
244+
if cur > max.Load() {
245+
max.Store(cur)
246+
}
247+
defer count.Dec()
248+
// simulate some work
249+
time.Sleep(20 * time.Millisecond)
250+
return &http.Response{
251+
Body: http.NoBody,
252+
}, nil
253+
})
254+
ctx = user.InjectOrgID(context.Background(), "foo")
255+
)
256+
257+
r, err := PrometheusCodec.EncodeRequest(ctx, &PrometheusRequest{
258+
Path: "/query_range",
259+
Start: time.Now().Add(time.Hour).Unix(),
260+
End: util.TimeToMillis(time.Now()),
261+
Step: int64(1 * time.Second * time.Millisecond),
262+
Query: `foo`,
263+
})
264+
require.Nil(t, err)
265+
266+
_, err = NewLimitedRoundTripper(downstream, PrometheusCodec, mockLimits{maxQueryParallelism: maxQueryParallelism},
267+
MiddlewareFunc(func(next Handler) Handler {
268+
return HandlerFunc(func(c context.Context, _ Request) (Response, error) {
269+
var wg sync.WaitGroup
270+
for i := 0; i < maxQueryParallelism+20; i++ {
271+
wg.Add(1)
272+
go func() {
273+
defer wg.Done()
274+
_, _ = next.Do(c, &PrometheusRequest{})
275+
}()
276+
}
277+
wg.Wait()
278+
return NewEmptyPrometheusResponse(), nil
279+
})
280+
}),
281+
).RoundTrip(r)
282+
require.NoError(t, err)
283+
maxFound := int(max.Load())
284+
require.LessOrEqual(t, maxFound, maxQueryParallelism, "max query parallelism: ", maxFound, " went over the configured one:", maxQueryParallelism)
285+
}
286+
287+
// TestLimitedRoundTripper_MaxQueryParallelismLateScheduling verifies that
// sub-requests fired off without being awaited (so possibly scheduled after
// the wrapped handler has already returned) don't make RoundTrip fail.
func TestLimitedRoundTripper_MaxQueryParallelismLateScheduling(t *testing.T) {
	var (
		maxQueryParallelism = 2
		downstream          = RoundTripFunc(func(_ *http.Request) (*http.Response, error) {
			// simulate some work
			time.Sleep(20 * time.Millisecond)
			return &http.Response{
				Body: http.NoBody,
			}, nil
		})
		ctx = user.InjectOrgID(context.Background(), "foo")
	)

	r, err := PrometheusCodec.EncodeRequest(ctx, &PrometheusRequest{
		Path:  "/query_range",
		Start: time.Now().Add(time.Hour).Unix(),
		End:   util.TimeToMillis(time.Now()),
		Step:  int64(1 * time.Second * time.Millisecond),
		Query: `foo`,
	})
	require.Nil(t, err)

	_, err = NewLimitedRoundTripper(downstream, PrometheusCodec, mockLimits{maxQueryParallelism: maxQueryParallelism},
		MiddlewareFunc(func(next Handler) Handler {
			return HandlerFunc(func(c context.Context, _ Request) (Response, error) {
				// fire up work and we don't wait.
				for i := 0; i < 10; i++ {
					go func() {
						_, _ = next.Do(c, &PrometheusRequest{})
					}()
				}
				return NewEmptyPrometheusResponse(), nil
			})
		}),
	).RoundTrip(r)
	require.NoError(t, err)
}
324+
325+
// TestLimitedRoundTripper_OriginalRequestContextCancellation verifies that
// cancelling the original request context promptly unblocks every
// sub-request — both the ones already running downstream and the ones still
// queued waiting for a worker.
func TestLimitedRoundTripper_OriginalRequestContextCancellation(t *testing.T) {
	var (
		maxQueryParallelism = 2
		downstream          = RoundTripFunc(func(req *http.Request) (*http.Response, error) {
			// Sleep for a long time or until the request context is canceled.
			select {
			case <-time.After(time.Minute):
				return &http.Response{Body: http.NoBody}, nil
			case <-req.Context().Done():
				return nil, req.Context().Err()
			}
		})
		reqCtx, reqCancel = context.WithCancel(user.InjectOrgID(context.Background(), "foo"))
	)

	r, err := PrometheusCodec.EncodeRequest(reqCtx, &PrometheusRequest{
		Path:  "/query_range",
		Start: time.Now().Add(time.Hour).Unix(),
		End:   util.TimeToMillis(time.Now()),
		Step:  int64(1 * time.Second * time.Millisecond),
		Query: `foo`,
	})
	require.Nil(t, err)

	_, err = NewLimitedRoundTripper(downstream, PrometheusCodec, mockLimits{maxQueryParallelism: maxQueryParallelism},
		MiddlewareFunc(func(next Handler) Handler {
			return HandlerFunc(func(c context.Context, _ Request) (Response, error) {
				var wg sync.WaitGroup

				// Fire up some work. Each sub-request will either be blocked in the sleep or in the queue
				// waiting to be scheduled.
				for i := 0; i < maxQueryParallelism+20; i++ {
					wg.Add(1)
					go func() {
						defer wg.Done()
						_, _ = next.Do(c, &PrometheusRequest{})
					}()
				}

				// Give it a bit a time to get the first sub-requests running.
				time.Sleep(100 * time.Millisecond)

				// Cancel the original request context.
				reqCancel()

				// Wait until all sub-requests have done. We expect all of them to cancel asap,
				// so it should take a very short time.
				waitStart := time.Now()
				wg.Wait()
				assert.Less(t, time.Since(waitStart).Milliseconds(), int64(100))

				return NewEmptyPrometheusResponse(), nil
			})
		}),
	).RoundTrip(r)
	require.NoError(t, err)
}

pkg/querier/queryrange/roundtrip.go

Lines changed: 14 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -202,7 +202,7 @@ func NewTripperware(
202202
return func(next http.RoundTripper) http.RoundTripper {
203203
// Finally, if the user selected any query range middleware, stitch it in.
204204
if len(queryRangeMiddleware) > 0 {
205-
queryrange := NewRoundTripper(next, codec, queryRangeMiddleware...)
205+
queryrange := NewLimitedRoundTripper(next, codec, limits, queryRangeMiddleware...)
206206
return RoundTripFunc(func(r *http.Request) (*http.Response, error) {
207207
isQueryRange := strings.HasSuffix(r.URL.Path, "/query_range")
208208
op := "query"
@@ -230,20 +230,20 @@ func NewTripperware(
230230
}
231231

232232
type roundTripper struct {
233-
next http.RoundTripper
234233
handler Handler
235234
codec Codec
236235
}
237236

238237
// NewRoundTripper merges a set of middlewares into an handler, then inject it into the `next` roundtripper
239238
// using the codec to translate requests and responses.
240239
func NewRoundTripper(next http.RoundTripper, codec Codec, middlewares ...Middleware) http.RoundTripper {
241-
transport := roundTripper{
242-
next: next,
240+
return roundTripper{
241+
handler: MergeMiddlewares(middlewares...).Wrap(roundTripperHandler{
242+
next: next,
243+
codec: codec,
244+
}),
243245
codec: codec,
244246
}
245-
transport.handler = MergeMiddlewares(middlewares...).Wrap(&transport)
246-
return transport
247247
}
248248

249249
func (q roundTripper) RoundTrip(r *http.Request) (*http.Response, error) {
@@ -264,8 +264,15 @@ func (q roundTripper) RoundTrip(r *http.Request) (*http.Response, error) {
264264
return q.codec.EncodeResponse(r.Context(), response)
265265
}
266266

267+
// roundTripperHandler is a handler that roundtrips requests to next roundtripper.
268+
// It basically encodes a Request from Handler.Do and decode response from next roundtripper.
269+
type roundTripperHandler struct {
270+
next http.RoundTripper
271+
codec Codec
272+
}
273+
267274
// Do implements Handler.
268-
func (q roundTripper) Do(ctx context.Context, r Request) (Response, error) {
275+
func (q roundTripperHandler) Do(ctx context.Context, r Request) (Response, error) {
269276
request, err := q.codec.EncodeRequest(ctx, r)
270277
if err != nil {
271278
return nil, err

0 commit comments

Comments
 (0)