Skip to content

Commit

Permalink
Adjust quota handling of missing or expired keys
Browse files Browse the repository at this point in the history
  • Loading branch information
Tit Petric authored and titpetric committed Sep 10, 2024
1 parent 570b651 commit 3fc577f
Show file tree
Hide file tree
Showing 2 changed files with 102 additions and 63 deletions.
30 changes: 18 additions & 12 deletions gateway/middleware_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -331,18 +331,24 @@ func TestSessionLimiter_RedisQuotaExceeded_PerAPI(t *testing.T) {
}
}

// for api1 - per api
sendReqAndCheckQuota(t, apis[0].APIID, 9, true)
sendReqAndCheckQuota(t, apis[0].APIID, 8, true)
sendReqAndCheckQuota(t, apis[0].APIID, 7, true)

// for api2 - per api
sendReqAndCheckQuota(t, apis[1].APIID, 1, true)
sendReqAndCheckQuota(t, apis[1].APIID, 0, true)

// for api3 - global
sendReqAndCheckQuota(t, apis[2].APIID, 24, false)
sendReqAndCheckQuota(t, apis[2].APIID, 23, false)
t.Run("For api1 - per api", func(t *testing.T) {
sendReqAndCheckQuota(t, apis[0].APIID, 9, true)
sendReqAndCheckQuota(t, apis[0].APIID, 8, true)
sendReqAndCheckQuota(t, apis[0].APIID, 7, true)
})

t.Run("For api2 - per api", func(t *testing.T) {
sendReqAndCheckQuota(t, apis[1].APIID, 1, true)
sendReqAndCheckQuota(t, apis[1].APIID, 0, true)
})

t.Run("For api3 - global", func(t *testing.T) {
sendReqAndCheckQuota(t, apis[2].APIID, 24, false)
sendReqAndCheckQuota(t, apis[2].APIID, 23, false)
sendReqAndCheckQuota(t, apis[2].APIID, 22, false)
sendReqAndCheckQuota(t, apis[2].APIID, 21, false)
sendReqAndCheckQuota(t, apis[2].APIID, 20, false)
})
}

func TestCopyAllowedURLs(t *testing.T) {
Expand Down
135 changes: 84 additions & 51 deletions gateway/session_manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -333,97 +333,128 @@ func (l *SessionLimiter) ForwardMessage(r *http.Request, session *user.SessionSt

// RedisQuotaExceeded returns true if the request should be blocked as over quota.
func (l *SessionLimiter) RedisQuotaExceeded(r *http.Request, session *user.SessionState, quotaKey, scope string, limit *user.APILimit, store storage.Handler, hashKeys bool) bool {
logger := log.WithFields(logrus.Fields{
"quotaMax": limit.QuotaMax,
"quotaRenewalRate": limit.QuotaRenewalRate,
})

if limit.QuotaMax <= 0 {
logger.Error("Quota disabled: quota max <= 0")
return false
}

// don't use the requests cancellation context
ctx := context.Background()

session.Touch()

quotaScope := ""
if scope != "" {
quotaScope = scope + "-"
}

key := session.KeyID

if hashKeys {
key = storage.HashStr(session.KeyID)
}

if quotaKey != "" {
key = quotaKey
}

now := time.Now().Truncate(0)

// rawKey is the redis key for quota
rawKey := QuotaKeyPrefix + quotaScope + key
quotaRenewalRate := time.Second * time.Duration(limit.QuotaRenewalRate)
quotaMax := limit.QuotaMax

// First, ensure a distributed lock
quotaRenewalRate := time.Second * time.Duration(limit.QuotaRenewalRate)
conn := l.limiterStorage
locker := limiter.NewLimiter(conn).Locker(rawKey)

if err := locker.Lock(ctx); err != nil {
log.WithError(err).Error("error locking quota key, blocking")
var expired, exists bool
var expiredAt time.Time

//logger = logger.WithField("key", rawKey)

dur, err := conn.PTTL(ctx, rawKey).Result()
if err != nil && !errors.Is(err, redis.Nil) {
logger.WithError(err).Error("error getting key TTL, blocking")
return true
}
defer func() {
if err := locker.Unlock(ctx); err != nil {
log.WithError(err).Error("error unlocking quota key")

// The command returns -2 if the key does not exist.
// The command returns -1 if the key exists but has no associated expire.
expired = dur < 0
exists = dur != -2

expiredAt = now.Add(dur)

logger = logger.WithFields(logrus.Fields{
"exists": exists,
"expired": expired,
"rawKey": rawKey,
// "expiredAt": expiredAt,
})

increment := func() bool {
var res *redis.IntCmd
conn.Pipelined(ctx, func(pipe redis.Pipeliner) error {
res = pipe.Incr(ctx, rawKey)
if quotaRenewalRate > 0 {
pipe.ExpireNX(ctx, rawKey, quotaRenewalRate)
} else {
// no expiration time
pipe.ExpireNX(ctx, rawKey, 0)
}
return nil
})
qInt, err := res.Result()
if err != nil {
logger.WithError(err).Error("error incrementing quota key")
return true
}
}()

var expired bool
var expiredAt time.Time
dur, err := conn.PTTL(ctx, rawKey).Result()
if err == nil || errors.Is(err, redis.Nil) {
if err == nil {
expiredAt = now.Add(dur)
} else {
expired = true
expiredAt = now.Add(quotaRenewalRate)
conn.Set(ctx, rawKey, 0, quotaRenewalRate)
blocked := qInt-1 >= limit.QuotaMax
remaining := limit.QuotaMax - qInt
if blocked {
remaining = 0
}
} else {
log.WithError(err).Warn("error getting key TTL, blocking")
return true
}

qInt, err := conn.Incr(ctx, rawKey).Result()
if err != nil && !errors.Is(err, redis.Nil) {
log.WithError(err).Error("can't update quota, blocking")
return true
logger = logger.WithField("quota", qInt-1)
logger = logger.WithField("blocked", blocked)
logger = logger.WithField("remaining", remaining)
logger.Debug("[QUOTA] Update quota key")

l.updateSessionQuota(session, scope, remaining, expiredAt.Unix())
return blocked
}

logFields := logrus.Fields{
"quota": qInt - 1,
"quotaMax": quotaMax,
"expired": expired,
"expiredAt": expiredAt,
// If exists and not expired, just increment it.
if exists && !expired {
return increment()
}

log.WithFields(logFields).Debug("[QUOTA] Request")
// if key is expired and can't renew, update the counter and
// block traffic going forward.
if limit.QuotaRenewalRate <= 0 {
return increment()
}

if qInt-1 >= quotaMax {
log.WithFields(logFields).Debug("[QUOTA] Limits reached")
// First, ensure a distributed lock
locker := limiter.NewLimiter(conn).Locker(rawKey)

if expired {
if quotaRenewalRate <= 0 {
return true
}
// Lock the key
if err := locker.Lock(ctx); err != nil {
// Increment the key if lock fails
return increment()
}

go store.DeleteRawKey(rawKey)
qInt = 1
} else {
return true
// Unlock the key when done
defer func() {
if err := locker.Unlock(ctx); err != nil {
logger.WithError(err).Error("error unlocking quota key")
}
}
}()

l.updateSessionQuota(session, scope, quotaMax-qInt, expiredAt.Unix())
return false
// locked: reset quota + increment
conn.Set(ctx, rawKey, 0, quotaRenewalRate)
return increment()
}

func GetAccessDefinitionByAPIIDOrSession(session *user.SessionState, api *APISpec) (accessDef *user.AccessDefinition, allowanceScope string, err error) {
Expand Down Expand Up @@ -473,4 +504,6 @@ func (*SessionLimiter) updateSessionQuota(session *user.SessionState, scope stri
session.QuotaRemaining = remaining
session.QuotaRenews = renews
}

session.Touch()
}

0 comments on commit 3fc577f

Please sign in to comment.