Rate limits: replace redis SET with INCRBY (#7782)
Add a new method, `BatchIncrement`, to issue `IncrBy` (instead of `Set`)
to Redis. This helps prevent the race condition that allows bursts of
near-simultaneous requests to, effectively, spend the same token.
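
To make the race concrete, here is a minimal sketch (not Boulder code; the go-redis `Cmdable` client, key name, and one-minute TTL are assumptions for illustration) contrasting the old read-modify-write pattern with an atomic increment:

```go
package example

import (
	"context"
	"errors"
	"time"

	"github.com/redis/go-redis/v9"
)

// spendWithSet is the old pattern: read the TAT, compute a new one, and write
// it back with SET. Two near-simultaneous callers can both read the same TAT
// and both write back the same new value, so one of the two costs is lost.
func spendWithSet(ctx context.Context, rdb redis.Cmdable, key string, cost time.Duration) error {
	tat, err := rdb.Get(ctx, key).Int64()
	if err != nil && !errors.Is(err, redis.Nil) {
		return err
	}
	return rdb.Set(ctx, key, tat+cost.Nanoseconds(), time.Minute).Err()
}

// spendWithIncrBy applies the delta on the Redis server atomically, so each
// concurrent caller advances the TAT by its own cost.
func spendWithIncrBy(ctx context.Context, rdb redis.Cmdable, key string, cost time.Duration) error {
	return rdb.IncrBy(ctx, key, cost.Nanoseconds()).Err()
}
```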

Call this new method when incrementing an existing key. New keys still
need to use `BatchSet` because Redis doesn't have a facility to, within
a single operation, increment _or_ set a default value if none exists.
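
A rough sketch of that split (assumed names; the real change is in `ratelimits/limiter.go` below): existing buckets go through `BatchIncrement`, while first-time buckets still take the `BatchSet` path, since INCRBY on a missing key counts up from zero rather than from a TAT anchored at the current time.

```go
package example

import (
	"context"
	"time"
)

// increment mirrors the shape introduced by this commit: a TAT delta and a
// refreshed TTL for the bucket key.
type increment struct {
	cost time.Duration
	ttl  time.Duration
}

type source interface {
	BatchSet(ctx context.Context, buckets map[string]time.Time) error
	BatchIncrement(ctx context.Context, buckets map[string]increment) error
}

// persist writes new buckets with SET and existing buckets with INCRBY.
func persist(ctx context.Context, s source, newBuckets map[string]time.Time, existing map[string]increment) error {
	if len(newBuckets) > 0 {
		// Redis has no single command meaning "INCRBY, or SET a default if
		// the key is missing", so first-time buckets take the SET path.
		if err := s.BatchSet(ctx, newBuckets); err != nil {
			return err
		}
	}
	if len(existing) > 0 {
		// Atomic server-side increments: concurrent spends can no longer
		// clobber one another.
		return s.BatchIncrement(ctx, existing)
	}
	return nil
}
```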

Add a new feature flag, `IncrementRateLimits`, gating the use of this
new method.

CPS Compliance Review: This feature flag does not change any behaviour
that is described or constrained by our CP/CPS. The closest related
behaviour is general API availability.

Fixes #7780
jprenken authored Nov 4, 2024
1 parent 2d69d7b commit 4adc65f
Showing 9 changed files with 160 additions and 39 deletions.
5 changes: 5 additions & 0 deletions features/features.go
@@ -109,6 +109,11 @@ type Config struct {
// get the AUTO_INCREMENT ID of each new authz without relying on MariaDB's
// unique "INSERT ... RETURNING" functionality.
InsertAuthzsIndividually bool

// IncrementRateLimits uses Redis' IncrBy, instead of Set, for rate limit
// accounting. This catches and denies spikes of requests much more
// reliably.
IncrementRateLimits bool
}

var fMu = new(sync.RWMutex)
68 changes: 56 additions & 12 deletions ratelimits/limiter.go
Expand Up @@ -14,6 +14,7 @@ import (
"github.com/prometheus/client_golang/prometheus"

berrors "github.com/letsencrypt/boulder/errors"
"github.com/letsencrypt/boulder/features"
)

const (
@@ -274,11 +275,13 @@ func (l *Limiter) BatchSpend(ctx context.Context, txns []Transaction) (*Decision
}
batchDecision := allowedDecision
newTATs := make(map[string]time.Time)
newBuckets := make(map[string]time.Time)
incrBuckets := make(map[string]increment)
txnOutcomes := make(map[Transaction]string)

for _, txn := range batch {
tat, exists := tats[txn.bucketKey]
if !exists {
tat, bucketExists := tats[txn.bucketKey]
if !bucketExists {
// First request from this client.
tat = l.clk.Now()
}
@@ -293,6 +296,15 @@
if d.allowed && (tat != d.newTAT) && txn.spend {
// New bucket state should be persisted.
newTATs[txn.bucketKey] = d.newTAT

if bucketExists {
incrBuckets[txn.bucketKey] = increment{
cost: time.Duration(txn.cost * txn.limit.emissionInterval),
ttl: time.Duration(txn.limit.burstOffset),
}
} else {
newBuckets[txn.bucketKey] = d.newTAT
}
}

if !txn.spendOnly() {
@@ -307,10 +319,28 @@ }
}
}

if batchDecision.allowed && len(newTATs) > 0 {
err = l.source.BatchSet(ctx, newTATs)
if err != nil {
return nil, err
if features.Get().IncrementRateLimits {
if batchDecision.allowed {
if len(newBuckets) > 0 {
err = l.source.BatchSet(ctx, newBuckets)
if err != nil {
return nil, err
}
}

if len(incrBuckets) > 0 {
err = l.source.BatchIncrement(ctx, incrBuckets)
if err != nil {
return nil, err
}
}
}
} else {
if batchDecision.allowed && len(newTATs) > 0 {
err = l.source.BatchSet(ctx, newTATs)
if err != nil {
return nil, err
}
}
}

@@ -365,10 +395,11 @@ func (l *Limiter) BatchRefund(ctx context.Context, txns []Transaction) (*Decisio

batchDecision := allowedDecision
newTATs := make(map[string]time.Time)
incrBuckets := make(map[string]increment)

for _, txn := range batch {
tat, exists := tats[txn.bucketKey]
if !exists {
tat, bucketExists := tats[txn.bucketKey]
if !bucketExists {
// Ignore non-existent bucket.
continue
}
@@ -382,13 +413,26 @@ func (l *Limiter) BatchRefund(ctx context.Context, txns []Transaction) (*Decisio
if d.allowed && tat != d.newTAT {
// New bucket state should be persisted.
newTATs[txn.bucketKey] = d.newTAT
incrBuckets[txn.bucketKey] = increment{
cost: time.Duration(-txn.cost * txn.limit.emissionInterval),
ttl: time.Duration(txn.limit.burstOffset),
}
}
}

if len(newTATs) > 0 {
err = l.source.BatchSet(ctx, newTATs)
if err != nil {
return nil, err
if features.Get().IncrementRateLimits {
if len(incrBuckets) > 0 {
err = l.source.BatchIncrement(ctx, incrBuckets)
if err != nil {
return nil, err
}
}
} else {
if len(newTATs) > 0 {
err = l.source.BatchSet(ctx, newTATs)
if err != nil {
return nil, err
}
}
}
return batchDecision, nil
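
For reference, the delta that the BatchSpend and BatchRefund changes above persist via IncrBy is simply cost multiplied by the limit's emission interval, with a refund expressed as a negative increment. A small sketch with assumed limit values (not Boulder's actual defaults):

```go
package example

import "time"

// tatDelta is the nanosecond amount IncrBy adds to a stored TAT:
// cost * emissionInterval, negative for refunds.
func tatDelta(cost int64, emissionInterval time.Duration) time.Duration {
	return time.Duration(cost) * emissionInterval
}

func demo() {
	// Assume a limit of 20 requests per 10s, i.e. an emissionInterval of 500ms.
	emission := 500 * time.Millisecond
	_ = tatDelta(1, emission)  // spend of 1: the TAT advances by 500ms
	_ = tatDelta(-2, emission) // refund of 2: the TAT moves back by 1s
}
```
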
19 changes: 17 additions & 2 deletions ratelimits/limiter_test.go
Expand Up @@ -4,6 +4,7 @@ import (
"context"
"math/rand/v2"
"net"
"os"
"testing"
"time"

@@ -12,6 +13,7 @@ import (

"github.com/letsencrypt/boulder/config"
berrors "github.com/letsencrypt/boulder/errors"
"github.com/letsencrypt/boulder/features"
"github.com/letsencrypt/boulder/metrics"
"github.com/letsencrypt/boulder/test"
)
@@ -38,6 +40,19 @@ func newTestTransactionBuilder(t *testing.T) *TransactionBuilder {
}

func setup(t *testing.T) (context.Context, map[string]*Limiter, *TransactionBuilder, clock.FakeClock, string) {
// Because all test cases in this file are affected by this feature flag, we
// want to run them all both with and without the feature flag. This way, we
// get one set of runs with and one set without. It's difficult to defer
// features.Reset() from the setup func (these tests are parallel); as long
// as this code doesn't test any other features, we don't need to.
//
// N.b. This is fragile. If a test case does call features.Reset(), it will
// not be testing the intended code path. But we expect to clean this up
// quickly.
if os.Getenv("BOULDER_CONFIG_DIR") == "test/config-next" {
features.Set(features.Config{IncrementRateLimits: true})
}

testCtx := context.Background()
clk := clock.NewFake()

@@ -304,8 +319,8 @@ func TestLimiter_InitializationViaCheckAndSpend(t *testing.T) {
test.AssertEquals(t, d.resetIn, time.Millisecond*50)
test.AssertEquals(t, d.retryIn, time.Duration(0))

// However, that cost should not be spent yet, a 0 cost check should
// tell us that we actually have 19 remaining.
// And that cost should have been spent; a 0 cost check should still
// tell us that we have 19 remaining.
d, err = l.Check(testCtx, txn0)
test.AssertNotError(t, err, "should not error")
test.Assert(t, d.allowed, "should be allowed")
25 changes: 24 additions & 1 deletion ratelimits/source.go
@@ -20,6 +20,13 @@ type source interface {
// the underlying storage client implementation).
BatchSet(ctx context.Context, bucketKeys map[string]time.Time) error

// BatchIncrement updates the TATs for the specified bucketKeys, similar to
// BatchSet. Implementations MUST ensure non-blocking operations by either:
// a) applying a deadline or timeout to the context WITHIN the method, or
// b) guaranteeing the operation will not block indefinitely (e.g. via
// the underlying storage client implementation).
BatchIncrement(ctx context.Context, buckets map[string]increment) error

// Get retrieves the TAT associated with the specified bucketKey (formatted
// as 'name:id'). Implementations MUST ensure non-blocking operations by
// either:
@@ -45,13 +52,20 @@ type source interface {
Delete(ctx context.Context, bucketKey string) error
}

type increment struct {
cost time.Duration
ttl time.Duration
}

// inmem is an in-memory implementation of the source interface used for
// testing.
type inmem struct {
sync.RWMutex
m map[string]time.Time
}

var _ source = (*inmem)(nil)

func newInmem() *inmem {
return &inmem{m: make(map[string]time.Time)}
}
@@ -65,6 +79,15 @@ func (in *inmem) BatchSet(_ context.Context, bucketKeys map[string]time.Time) er
return nil
}

func (in *inmem) BatchIncrement(_ context.Context, bucketKeys map[string]increment) error {
in.Lock()
defer in.Unlock()
for k, v := range bucketKeys {
in.m[k] = in.m[k].Add(v.cost)
}
return nil
}

func (in *inmem) Get(_ context.Context, bucketKey string) (time.Time, error) {
in.RLock()
defer in.RUnlock()
@@ -82,7 +105,7 @@ func (in *inmem) BatchGet(_ context.Context, bucketKeys []string) (map[string]ti
for _, k := range bucketKeys {
tat, ok := in.m[k]
if !ok {
tats[k] = time.Time{}
continue
}
tats[k] = tat
}
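
One usage note on the BatchIncrement contract documented in the source interface above: the caller or the implementation is expected to bound the call so it cannot block indefinitely. A minimal sketch, reusing the unexported source and increment types from this package and assuming a 250ms timeout chosen purely for illustration:

```go
package ratelimits

import (
	"context"
	"time"
)

// incrementWithDeadline is a hypothetical helper: it bounds the Redis
// round-trip so a slow or unreachable shard cannot block the caller.
func incrementWithDeadline(src source, buckets map[string]increment) error {
	ctx, cancel := context.WithTimeout(context.Background(), 250*time.Millisecond)
	defer cancel()
	return src.BatchIncrement(ctx, buckets)
}
```
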
44 changes: 34 additions & 10 deletions ratelimits/source_redis.go
@@ -83,7 +83,6 @@ func (r *RedisSource) observeLatency(call string, latency time.Duration, err err

// BatchSet stores TATs at the specified bucketKeys using a pipelined Redis
// Transaction in order to reduce the number of round-trips to each Redis shard.
// An error is returned if the operation failed and nil otherwise.
func (r *RedisSource) BatchSet(ctx context.Context, buckets map[string]time.Time) error {
start := r.clk.Now()

@@ -109,9 +108,35 @@ func (r *RedisSource) BatchSet(ctx context.Context, buckets map[string]time.Time
return nil
}

// Get retrieves the TAT at the specified bucketKey. An error is returned if the
// operation failed and nil otherwise. If the bucketKey does not exist,
// ErrBucketNotFound is returned.
// BatchIncrement updates TATs for the specified bucketKeys using a pipelined
// Redis Transaction in order to reduce the number of round-trips to each Redis
// shard.
func (r *RedisSource) BatchIncrement(ctx context.Context, buckets map[string]increment) error {
start := r.clk.Now()

pipeline := r.client.Pipeline()
for bucketKey, incr := range buckets {
pipeline.IncrBy(ctx, bucketKey, incr.cost.Nanoseconds())
pipeline.Expire(ctx, bucketKey, incr.ttl)
}
_, err := pipeline.Exec(ctx)
if err != nil {
r.observeLatency("batchincrby", r.clk.Since(start), err)
return err
}

totalLatency := r.clk.Since(start)
perSetLatency := totalLatency / time.Duration(len(buckets))
for range buckets {
r.observeLatency("batchincrby_entry", perSetLatency, nil)
}

r.observeLatency("batchincrby", totalLatency, nil)
return nil
}

// Get retrieves the TAT at the specified bucketKey. If the bucketKey does not
// exist, ErrBucketNotFound is returned.
func (r *RedisSource) Get(ctx context.Context, bucketKey string) (time.Time, error) {
start := r.clk.Now()

@@ -133,8 +158,8 @@ func (r *RedisSource) Get(ctx context.Context, bucketKey string) (time.Time, err

// BatchGet retrieves the TATs at the specified bucketKeys using a pipelined
// Redis Transaction in order to reduce the number of round-trips to each Redis
// shard. An error is returned if the operation failed and nil otherwise. If a
// bucketKey does not exist, it WILL NOT be included in the returned map.
// shard. If a bucketKey does not exist, it WILL NOT be included in the returned
// map.
func (r *RedisSource) BatchGet(ctx context.Context, bucketKeys []string) (map[string]time.Time, error) {
start := r.clk.Now()

@@ -184,9 +209,8 @@ func (r *RedisSource) BatchGet(ctx context.Context, bucketKeys []string) (map[st
return tats, nil
}

// Delete deletes the TAT at the specified bucketKey ('name:id'). It returns an
// error if the operation failed and nil otherwise. A nil return value does not
// indicate that the bucketKey existed.
// Delete deletes the TAT at the specified bucketKey ('name:id'). A nil return
// value does not indicate that the bucketKey existed.
func (r *RedisSource) Delete(ctx context.Context, bucketKey string) error {
start := r.clk.Now()

@@ -201,7 +225,7 @@ func (r *RedisSource) Delete(ctx context.Context, bucketKey string) error {
}

// Ping checks that each shard of the *redis.Ring is reachable using the PING
// command. It returns an error if any shard is unreachable and nil otherwise.
// command.
func (r *RedisSource) Ping(ctx context.Context) error {
start := r.clk.Now()

29 changes: 20 additions & 9 deletions ratelimits/source_redis_test.go
@@ -77,15 +77,16 @@ func TestRedisSource_BatchSetAndGet(t *testing.T) {
"shard2": "10.33.33.5:4218",
})

now := clk.Now()
val1 := now.Add(time.Second)
val2 := now.Add(time.Second * 2)
val3 := now.Add(time.Second * 3)

set := map[string]time.Time{
"test1": val1,
"test2": val2,
"test3": val3,
"test1": clk.Now().Add(time.Second),
"test2": clk.Now().Add(time.Second * 2),
"test3": clk.Now().Add(time.Second * 3),
}

incr := map[string]increment{
"test1": {time.Second, time.Minute},
"test2": {time.Second * 2, time.Minute},
"test3": {time.Second * 3, time.Minute},
}

err := s.BatchSet(context.Background(), set)
Expand All @@ -95,7 +96,17 @@ func TestRedisSource_BatchSetAndGet(t *testing.T) {
test.AssertNotError(t, err, "BatchGet() should not error")

for k, v := range set {
test.Assert(t, got[k].Equal(v), "BatchGet() should return the values set by BatchSet()")
test.AssertEquals(t, got[k], v)
}

err = s.BatchIncrement(context.Background(), incr)
test.AssertNotError(t, err, "BatchIncrement() should not error")

got, err = s.BatchGet(context.Background(), []string{"test1", "test2", "test3"})
test.AssertNotError(t, err, "BatchGet() should not error")

for k := range set {
test.AssertEquals(t, got[k], set[k].Add(incr[k].cost))
}

// Test that BatchGet() returns a zero time for a key that does not exist.
3 changes: 2 additions & 1 deletion test/config-next/ra.json
@@ -130,7 +130,8 @@
},
"features": {
"AsyncFinalize": true,
"UseKvLimitsForNewOrder": true
"UseKvLimitsForNewOrder": true,
"IncrementRateLimits": true
},
"ctLogs": {
"stagger": "500ms",
3 changes: 2 additions & 1 deletion test/config-next/wfe2.json
@@ -129,7 +129,8 @@
"features": {
"ServeRenewalInfo": true,
"CheckIdentifiersPaused": true,
"UseKvLimitsForNewOrder": true
"UseKvLimitsForNewOrder": true,
"IncrementRateLimits": true
},
"certProfiles": {
"legacy": "The normal profile you know and love",