receive: rework splitting req into multiple
The tenant is hashed together with the time series labels, so we need to
take it into account. Move the functionality into the distribute function.

Signed-off-by: Giedrius Statkevičius <giedrius.statkevicius@vinted.com>
GiedriusS committed Jan 31, 2024
1 parent 473cf7a commit ad1c546
Showing 2 changed files with 86 additions and 22 deletions.
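For context, the core of the change is per-series tenant resolution inside distributeTimeseriesToReplicas: if a series carries the __meta_tenant_id meta label (model.MetaLabelPrefix + "tenant_id"), that value overrides the tenant taken from the HTTP request, and the label is stripped before hashing and forwarding. Below is a minimal, self-contained sketch of that resolution; resolveTenant is a hypothetical helper for illustration, not a function added by this patch, and it assumes a Prometheus labels package version where Builder.Labels() takes no arguments, matching the diff.

package main

import (
	"fmt"

	"github.com/prometheus/common/model"
	"github.com/prometheus/prometheus/model/labels"
)

// Mirrors the constant added in handler.go: "__meta_" + "tenant_id".
const metaLabelTenantID = model.MetaLabelPrefix + "tenant_id"

// resolveTenant prefers the series' __meta_tenant_id label over the tenant
// resolved from the HTTP request, and deletes the meta label so it is not
// part of the stored series.
func resolveTenant(tenantHTTP string, lbls labels.Labels) (string, labels.Labels) {
	tenant := tenantHTTP
	if v := lbls.Get(metaLabelTenantID); v != "" {
		tenant = v
		b := labels.NewBuilder(lbls)
		b.Del(metaLabelTenantID)
		lbls = b.Labels()
	}
	return tenant, lbls
}

func main() {
	in := labels.FromStrings("a", "b", metaLabelTenantID, "bar")
	tenant, out := resolveTenant("foo", in)
	fmt.Println(tenant, out) // bar {a="b"}
}
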
63 changes: 41 additions & 22 deletions pkg/receive/handler.go
@@ -28,7 +28,9 @@ import (
"github.com/pkg/errors"
"github.com/prometheus/client_golang/prometheus"
"github.com/prometheus/client_golang/prometheus/promauto"
"github.com/prometheus/common/model"
"github.com/prometheus/common/route"
"github.com/prometheus/prometheus/model/labels"
"github.com/prometheus/prometheus/model/relabel"
"github.com/prometheus/prometheus/storage"
"github.com/prometheus/prometheus/tsdb"
@@ -64,6 +66,8 @@ const (
// Labels for metrics.
labelSuccess = "success"
labelError = "error"

metaLabelTenantID = model.MetaLabelPrefix + "tenant_id"
)

var (
@@ -458,15 +462,15 @@ func (h *Handler) receiveHTTP(w http.ResponseWriter, r *http.Request) {
span.SetTag("receiver.mode", string(h.receiverMode))
defer span.Finish()

tenant, err := tenancy.GetTenantFromHTTP(r, h.options.TenantHeader, h.options.DefaultTenantID, h.options.TenantField)
tenantHTTP, err := tenancy.GetTenantFromHTTP(r, h.options.TenantHeader, h.options.DefaultTenantID, h.options.TenantField)
if err != nil {
level.Error(h.logger).Log("msg", "error getting tenant from HTTP", "err", err)
http.Error(w, err.Error(), http.StatusBadRequest)
return
}

tLogger := log.With(h.logger, "tenant", tenant)
span.SetTag("tenant", tenant)
tLogger := log.With(h.logger, "tenantHTTP", tenantHTTP)
span.SetTag("tenantHTTP", tenantHTTP)

writeGate := h.Limiter.WriteGate()
tracing.DoInSpan(r.Context(), "receive_write_gate_ismyturn", func(ctx context.Context) {
@@ -479,7 +483,7 @@ func (h *Handler) receiveHTTP(w http.ResponseWriter, r *http.Request) {
return
}

under, err := h.Limiter.HeadSeriesLimiter().isUnderLimit(tenant)
under, err := h.Limiter.HeadSeriesLimiter().isUnderLimit(tenantHTTP)
if err != nil {
level.Error(tLogger).Log("msg", "error while limiting", "err", err.Error())
}
@@ -495,7 +499,7 @@ func (h *Handler) receiveHTTP(w http.ResponseWriter, r *http.Request) {
// Since this is receive hot path, grow upfront saving allocations and CPU time.
compressed := bytes.Buffer{}
if r.ContentLength >= 0 {
if !requestLimiter.AllowSizeBytes(tenant, r.ContentLength) {
if !requestLimiter.AllowSizeBytes(tenantHTTP, r.ContentLength) {
http.Error(w, "write request too large", http.StatusRequestEntityTooLarge)
return
}
@@ -515,7 +519,7 @@ func (h *Handler) receiveHTTP(w http.ResponseWriter, r *http.Request) {
return
}

if !requestLimiter.AllowSizeBytes(tenant, int64(len(reqBuf))) {
if !requestLimiter.AllowSizeBytes(tenantHTTP, int64(len(reqBuf))) {
http.Error(w, "write request too large", http.StatusRequestEntityTooLarge)
return
}
@@ -551,7 +555,7 @@ func (h *Handler) receiveHTTP(w http.ResponseWriter, r *http.Request) {
return
}

if !requestLimiter.AllowSeries(tenant, int64(len(wreq.Timeseries))) {
if !requestLimiter.AllowSeries(tenantHTTP, int64(len(wreq.Timeseries))) {
http.Error(w, "too many timeseries", http.StatusRequestEntityTooLarge)
return
}
@@ -560,7 +564,7 @@ func (h *Handler) receiveHTTP(w http.ResponseWriter, r *http.Request) {
for _, timeseries := range wreq.Timeseries {
totalSamples += len(timeseries.Samples)
}
if !requestLimiter.AllowSamples(tenant, int64(totalSamples)) {
if !requestLimiter.AllowSamples(tenantHTTP, int64(totalSamples)) {
http.Error(w, "too many samples", http.StatusRequestEntityTooLarge)
return
}
@@ -573,7 +577,7 @@ func (h *Handler) receiveHTTP(w http.ResponseWriter, r *http.Request) {
}

responseStatusCode := http.StatusOK
if err := h.handleRequest(ctx, rep, tenant, &wreq); err != nil {
if err := h.handleRequest(ctx, rep, tenantHTTP, &wreq); err != nil {
level.Debug(tLogger).Log("msg", "failed to handle request", "err", err.Error())
switch errors.Cause(err) {
case errNotReady:
@@ -590,12 +594,12 @@ func (h *Handler) receiveHTTP(w http.ResponseWriter, r *http.Request) {
}
http.Error(w, err.Error(), responseStatusCode)
}
h.writeTimeseriesTotal.WithLabelValues(strconv.Itoa(responseStatusCode), tenant).Observe(float64(len(wreq.Timeseries)))
h.writeSamplesTotal.WithLabelValues(strconv.Itoa(responseStatusCode), tenant).Observe(float64(totalSamples))
h.writeTimeseriesTotal.WithLabelValues(strconv.Itoa(responseStatusCode), tenantHTTP).Observe(float64(len(wreq.Timeseries)))
h.writeSamplesTotal.WithLabelValues(strconv.Itoa(responseStatusCode), tenantHTTP).Observe(float64(totalSamples))
}

func (h *Handler) handleRequest(ctx context.Context, rep uint64, tenant string, wreq *prompb.WriteRequest) error {
tLogger := log.With(h.logger, "tenant", tenant)
func (h *Handler) handleRequest(ctx context.Context, rep uint64, tenantHTTP string, wreq *prompb.WriteRequest) error {
tLogger := log.With(h.logger, "tenantHTTP", tenantHTTP)

// This replica value is used to detect cycles in cyclic topologies.
// A non-zero value indicates that the request has already been replicated by a previous receive instance.
@@ -623,7 +627,7 @@ func (h *Handler) handleRequest(ctx context.Context, rep uint64, tenant string,
// Forward any time series as necessary. All time series
// destined for the local node will be written to the receiver.
// Time series will be replicated as necessary.
return h.forward(ctx, tenant, r, wreq)
return h.forward(ctx, tenantHTTP, r, wreq)
}

// forward accepts a write request, batches its time series by
@@ -634,7 +638,7 @@ func (h *Handler) handleRequest(ctx context.Context, rep uint64, tenant string,
// unless the request needs to be replicated.
// The function only returns when all requests have finished
// or the context is canceled.
func (h *Handler) forward(ctx context.Context, tenant string, r replica, wreq *prompb.WriteRequest) error {
func (h *Handler) forward(ctx context.Context, tenantHTTP string, r replica, wreq *prompb.WriteRequest) error {
span, ctx := tracing.StartSpan(ctx, "receive_fanout_forward")
defer span.Finish()

@@ -656,7 +660,7 @@ func (h *Handler) forward(ctx context.Context, tenant string, r replica, wreq *p
}

params := remoteWriteParams{
tenant: tenant,
tenantHTTP: tenantHTTP,
writeRequest: wreq,
replicas: replicas,
alreadyReplicated: r.replicated,
@@ -666,7 +670,7 @@ func (h *Handler) forward(ctx context.Context, tenant string, r replica, wreq *p
}

type remoteWriteParams struct {
tenant string
tenantHTTP string
writeRequest *prompb.WriteRequest
replicas []uint64
alreadyReplicated bool
@@ -685,13 +689,13 @@ func (h *Handler) fanoutForward(ctx context.Context, params remoteWriteParams) e
}
}()

logTags := []interface{}{"tenant", params.tenant}
logTags := []interface{}{"tenantHTTP", params.tenantHTTP}
if id, ok := middleware.RequestIDFromContext(ctx); ok {
logTags = append(logTags, "request-id", id)
}
requestLogger := log.With(h.logger, logTags...)

localWrites, remoteWrites, err := h.distributeTimeseriesToReplicas(params.tenant, params.replicas, params.writeRequest.Timeseries)
localWrites, remoteWrites, err := h.distributeTimeseriesToReplicas(params.tenantHTTP, params.replicas, params.writeRequest.Timeseries)
if err != nil {
level.Error(requestLogger).Log("msg", "failed to distribute timeseries to replicas", "err", err)
return err
Expand Down Expand Up @@ -765,7 +769,7 @@ func (h *Handler) fanoutForward(ctx context.Context, params remoteWriteParams) e
// The first return value are the series that should be written to the local node. The second return value are the
// series that should be written to remote nodes.
func (h *Handler) distributeTimeseriesToReplicas(
tenant string,
tenantHTTP string,
replicas []uint64,
timeseries []prompb.TimeSeries,
) (map[endpointReplica]trackedSeries, map[endpointReplica]trackedSeries, error) {
@@ -774,6 +778,21 @@ func (h *Handler) distributeTimeseriesToReplicas(
remoteWrites := make(map[endpointReplica]trackedSeries)
localWrites := make(map[endpointReplica]trackedSeries)
for tsIndex, ts := range timeseries {
var tenant = tenantHTTP

lbls := labelpb.ZLabelsToPromLabels(ts.Labels)

tenantLabel := lbls.Get(metaLabelTenantID)
if tenantLabel != "" {
tenant = tenantLabel

newLabels := labels.NewBuilder(lbls)
newLabels.Del(metaLabelTenantID)

ts.Labels = labelpb.ZLabelsFromPromLabels(
newLabels.Labels(),
)
}
for _, rn := range replicas {
endpoint, err := h.hashring.GetN(tenant, &ts, rn)
if err != nil {
@@ -812,13 +831,13 @@ func (h *Handler) sendWrites(
// Do the writes to the local node first. This should be easy and fast.
for writeDestination := range localWrites {
func(writeDestination endpointReplica) {
h.sendLocalWrite(ctx, writeDestination, params.tenant, localWrites[writeDestination], responses)
h.sendLocalWrite(ctx, writeDestination, params.tenantHTTP, localWrites[writeDestination], responses)
}(writeDestination)
}

// Do the writes to remote nodes. Run them all in parallel.
for writeDestination := range remoteWrites {
h.sendRemoteWrite(ctx, params.tenant, writeDestination, remoteWrites[writeDestination], params.alreadyReplicated, responses, wg)
h.sendRemoteWrite(ctx, params.tenantHTTP, writeDestination, remoteWrites[writeDestination], params.alreadyReplicated, responses, wg)
}
}

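To make the new data flow in handler.go concrete: forward still batches series per (endpoint, replica), but after this change the hashring lookup uses the tenant resolved per series rather than the single HTTP tenant, so one incoming request can fan out under several tenants. A rough sketch of that grouping follows; endpointReplica mirrors the real map key, while series, getN, and distribute are simplified stand-ins for prompb.TimeSeries, Hashring.GetN, and distributeTimeseriesToReplicas, not the actual Thanos types.

package main

import "fmt"

// endpointReplica mirrors the map key used by distributeTimeseriesToReplicas.
type endpointReplica struct {
	endpoint string
	replica  uint64
}

// series is a stand-in for prompb.TimeSeries in this sketch.
type series struct {
	tenant string // already resolved from __meta_tenant_id or the HTTP tenant
	name   string
}

// getN is a hypothetical stand-in for Hashring.GetN: it picks an endpoint for
// a (tenant, series, replica) triple. Here it just returns a fixed endpoint.
func getN(tenant string, s series, replica uint64) (string, error) {
	return "http://localhost:9090", nil
}

// distribute groups series per (endpoint, replica), the same shape of result
// that forward/fanoutForward consume for local and remote writes.
func distribute(replicas []uint64, ss []series) (map[endpointReplica][]series, error) {
	out := make(map[endpointReplica][]series)
	for _, s := range ss {
		for _, rn := range replicas {
			endpoint, err := getN(s.tenant, s, rn)
			if err != nil {
				return nil, err
			}
			key := endpointReplica{endpoint: endpoint, replica: rn}
			out[key] = append(out[key], s)
		}
	}
	return out, nil
}

func main() {
	got, _ := distribute([]uint64{0}, []series{{tenant: "bar", name: "a"}, {tenant: "boo", name: "b"}})
	fmt.Println(len(got)) // 1: both series land on the same endpoint/replica in this fake hashring
}
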
45 changes: 45 additions & 0 deletions pkg/receive/handler_test.go
@@ -39,6 +39,7 @@ import (
"github.com/prometheus/prometheus/model/relabel"
"github.com/prometheus/prometheus/storage"
"github.com/prometheus/prometheus/tsdb"
"github.com/stretchr/testify/require"

"github.com/efficientgo/core/testutil"

@@ -1664,3 +1665,47 @@ func TestHandlerEarlyStop(t *testing.T) {
testutil.NotOk(t, err)
testutil.Equals(t, "http: Server closed", err.Error())
}

type hashringSeenTenants struct {
Hashring

seenTenants map[string]struct{}
}

func (h *hashringSeenTenants) GetN(tenant string, ts *prompb.TimeSeries, n uint64) (string, error) {
if h.seenTenants == nil {
h.seenTenants = map[string]struct{}{}
}
h.seenTenants[tenant] = struct{}{}
return h.Hashring.GetN(tenant, ts, n)
}

func TestDistributeSeries(t *testing.T) {
h := NewHandler(nil, &Options{})

hashring, err := newSimpleHashring([]Endpoint{
{
Address: "http://localhost:9090",
},
})
require.NoError(t, err)
hr := &hashringSeenTenants{Hashring: hashring}
h.Hashring(hr)

_, remote, err := h.distributeTimeseriesToReplicas(
"foo",
[]uint64{0},
[]prompb.TimeSeries{
{
Labels: labelpb.ZLabelsFromPromLabels(labels.FromStrings("a", "b", metaLabelTenantID, "bar")),
},
{
Labels: labelpb.ZLabelsFromPromLabels(labels.FromStrings("b", "a", metaLabelTenantID, "boo")),
},
},
)
require.NoError(t, err)
require.Equal(t, 1, labelpb.ZLabelsToPromLabels(remote[endpointReplica{endpoint: "http://localhost:9090", replica: 0}].timeSeries[0].Labels).Len())
require.Equal(t, 1, labelpb.ZLabelsToPromLabels(remote[endpointReplica{endpoint: "http://localhost:9090", replica: 0}].timeSeries[1].Labels).Len())
require.Equal(t, map[string]struct{}{"bar": {}, "boo": {}}, hr.seenTenants)
}
