Skip to content

querymiddleware: Pool snappy writer in shard activity series #7308

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Feb 8, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 28 additions & 11 deletions pkg/frontend/querymiddleware/shard_active_series.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,18 @@ var (
}
)

var snappyWriterPool sync.Pool

func getSnappyWriter(w io.Writer) *s2.Writer {
sw := snappyWriterPool.Get()
if sw == nil {
return s2.NewWriter(w)
}
enc := sw.(*s2.Writer)
enc.Reset(w)
return enc
}

type shardActiveSeriesMiddleware struct {
upstream http.RoundTripper
limits Limits
Expand Down Expand Up @@ -257,7 +269,7 @@ func shardedSelector(shardCount, currentShard int, expr parser.Expr) (parser.Exp
}, nil
}

func (s *shardActiveSeriesMiddleware) mergeResponses(ctx context.Context, responses []*http.Response, acceptEncoding string) *http.Response {
func (s *shardActiveSeriesMiddleware) mergeResponses(ctx context.Context, responses []*http.Response, encoding string) *http.Response {
reader, writer := io.Pipe()

items := make(chan *labels.Builder, len(responses))
Expand Down Expand Up @@ -326,29 +338,34 @@ func (s *shardActiveSeriesMiddleware) mergeResponses(ctx context.Context, respon
close(items)
}()

response := &http.Response{Body: reader, StatusCode: http.StatusOK, Header: http.Header{}}
response.Header.Set("Content-Type", "application/json")
if acceptEncoding == encodingTypeSnappyFramed {
response.Header.Set("Content-Encoding", encodingTypeSnappyFramed)
resp := &http.Response{Body: reader, StatusCode: http.StatusOK, Header: http.Header{}}
resp.Header.Set("Content-Type", "application/json")
if encoding == encodingTypeSnappyFramed {
resp.Header.Set("Content-Encoding", encodingTypeSnappyFramed)
}

go s.writeMergedResponse(ctx, g.Wait, writer, items, acceptEncoding)
go s.writeMergedResponse(ctx, g.Wait, writer, items, encoding)

return response
return resp
}

func (s *shardActiveSeriesMiddleware) writeMergedResponse(ctx context.Context, check func() error, w io.WriteCloser, items <-chan *labels.Builder, encodingType string) {
func (s *shardActiveSeriesMiddleware) writeMergedResponse(ctx context.Context, check func() error, w io.WriteCloser, items <-chan *labels.Builder, encoding string) {
defer w.Close()

span, _ := opentracing.StartSpanFromContext(ctx, "shardActiveSeries.writeMergedResponse")
defer span.Finish()

var out io.Writer = w
if encodingType == encodingTypeSnappyFramed {
if encoding == encodingTypeSnappyFramed {
span.LogFields(otlog.String("encoding", encodingTypeSnappyFramed))
enc := s2.NewWriter(w)
defer enc.Close()
enc := getSnappyWriter(w)
out = enc
defer func() {
enc.Close()
// Reset the encoder before putting it back to pool to avoid it to hold the writer.
enc.Reset(nil)
snappyWriterPool.Put(enc)
}()
} else {
span.LogFields(otlog.String("encoding", "none"))
}
Expand Down
87 changes: 80 additions & 7 deletions pkg/frontend/querymiddleware/shard_active_series_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import (
"net/url"
"strconv"
"strings"
"sync"
"testing"

"github.com/go-kit/log"
Expand Down Expand Up @@ -270,12 +271,6 @@ func Test_shardActiveSeriesMiddleware_RoundTrip(t *testing.T) {
// Stub upstream with valid or invalid responses.
var requestCount atomic.Int32
upstream := RoundTripFunc(func(r *http.Request) (*http.Response, error) {
defer func(body io.ReadCloser) {
if body != nil {
_ = body.Close()
}
}(r.Body)

_, _, err := user.ExtractOrgIDFromHTTPRequest(r)
require.NoError(t, err)
_, err = user.ExtractOrgID(r.Context())
Expand Down Expand Up @@ -358,7 +353,85 @@ func Test_shardActiveSeriesMiddleware_RoundTrip(t *testing.T) {
}
}

func Test_shardActiveSeriesMiddleware_RoundTrip_concurrent(t *testing.T) {
const shardCount = 4

upstream := RoundTripFunc(func(r *http.Request) (*http.Response, error) {
require.NoError(t, r.ParseForm())
req, err := cardinality.DecodeActiveSeriesRequestFromValues(r.Form)
require.NoError(t, err)
shard, _, err := sharding.ShardFromMatchers(req.Matchers)
require.NoError(t, err)
require.NotNil(t, shard)

resp := fmt.Sprintf(`{"data": [{"__name__": "metric-%d"}]}`, shard.ShardIndex)

return &http.Response{StatusCode: http.StatusOK, Body: io.NopCloser(strings.NewReader(resp))}, nil
})

s := newShardActiveSeriesMiddleware(
upstream,
mockLimits{maxShardedQueries: shardCount, totalShards: shardCount},
log.NewNopLogger(),
)

assertRoundTrip := func(t *testing.T, trip http.RoundTripper, req *http.Request) {
resp, err := trip.RoundTrip(req)
require.NoError(t, err)
defer resp.Body.Close()

assert.Equal(t, http.StatusOK, resp.StatusCode)

var body io.Reader = resp.Body
if resp.Header.Get("Content-Encoding") == encodingTypeSnappyFramed {
body = s2.NewReader(resp.Body)
}

// For this test, if we can decode the response, it is enough to guaranty it worked. We proof actual validity
// of all kinds of responses in the tests above.
var res result
err = json.NewDecoder(body).Decode(&res)
require.NoError(t, err)
require.Len(t, res.Data, shardCount)
}

const reqCount = 20

var wg sync.WaitGroup
defer wg.Wait()

wg.Add(reqCount)

for n := reqCount; n > 0; n-- {
go func(n int) {
defer wg.Done()

req := httptest.NewRequest("POST", "/active_series", strings.NewReader(`selector={__name__=~"metric-.*"}`))
req.Header.Add("Content-Type", "application/x-www-form-urlencoded")

// Send every other request as snappy to proof the middleware doesn't mess up body encoders
if n%2 == 0 {
req.Header.Add("Accept-Encoding", encodingTypeSnappyFramed)
}

req = req.WithContext(user.InjectOrgID(req.Context(), "test"))

assertRoundTrip(t, s, req)
}(n)
}
}

func BenchmarkActiveSeriesMiddlewareMergeResponses(b *testing.B) {
b.Run("encoding=none", func(b *testing.B) {
benchmarkActiveSeriesMiddlewareMergeResponses(b, "")
})

b.Run("encoding=snappy", func(b *testing.B) {
benchmarkActiveSeriesMiddlewareMergeResponses(b, encodingTypeSnappyFramed)
})
}

func benchmarkActiveSeriesMiddlewareMergeResponses(b *testing.B, encoding string) {
type activeSeriesResponse struct {
Data []labels.Labels `json:"data"`
}
Expand Down Expand Up @@ -392,7 +465,7 @@ func BenchmarkActiveSeriesMiddlewareMergeResponses(b *testing.B) {
b.ReportAllocs()

for i := 0; i < b.N; i++ {
resp := s.mergeResponses(context.Background(), benchResponses[i], "")
resp := s.mergeResponses(context.Background(), benchResponses[i], encoding)

_, _ = io.Copy(io.Discard, resp.Body)
_ = resp.Body.Close()
Expand Down