Skip to content

Commit

Permalink
Fix RateLimiter isn't thread safe (goravel#84)
Browse files Browse the repository at this point in the history
  • Loading branch information
hwbrzzl authored Apr 2, 2023
1 parent 915c494 commit ff82643
Show file tree
Hide file tree
Showing 4 changed files with 80 additions and 77 deletions.
9 changes: 5 additions & 4 deletions cache/memory.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"time"

"github.com/patrickmn/go-cache"
"github.com/spf13/cast"

contractscache "github.com/goravel/framework/contracts/cache"
)
Expand Down Expand Up @@ -91,31 +92,31 @@ func (r *Memory) GetBool(key string, def ...bool) bool {
}
res := r.Get(key, def[0])

return res.(bool)
return cast.ToBool(res)
}

func (r *Memory) GetInt(key string, def ...int) int {
if len(def) == 0 {
def = append(def, 0)
}

return r.Get(key, def[0]).(int)
return cast.ToInt(r.Get(key, def[0]))
}

func (r *Memory) GetInt64(key string, def ...int64) int64 {
if len(def) == 0 {
def = append(def, 0)
}

return r.Get(key, def[0]).(int64)
return cast.ToInt64(r.Get(key, def[0]))
}

func (r *Memory) GetString(key string, def ...string) string {
if len(def) == 0 {
def = append(def, "")
}

return r.Get(key, def[0]).(string)
return cast.ToString(r.Get(key, def[0]))
}

//Has Check an item exists in the cache.
Expand Down
9 changes: 5 additions & 4 deletions cache/redis.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (

"github.com/go-redis/redis/v8"
"github.com/pkg/errors"
"github.com/spf13/cast"

"github.com/goravel/framework/contracts/cache"
"github.com/goravel/framework/facades"
Expand Down Expand Up @@ -118,7 +119,7 @@ func (r *Redis) GetBool(key string, def ...bool) bool {
return val == "1"
}

return res.(bool)
return cast.ToBool(res)
}

func (r *Redis) GetInt(key string, def ...int) int {
Expand All @@ -135,7 +136,7 @@ func (r *Redis) GetInt(key string, def ...int) int {
return i
}

return res.(int)
return cast.ToInt(res)
}

func (r *Redis) GetInt64(key string, def ...int64) int64 {
Expand All @@ -152,14 +153,14 @@ func (r *Redis) GetInt64(key string, def ...int64) int64 {
return i
}

return res.(int64)
return cast.ToInt64(res)
}

func (r *Redis) GetString(key string, def ...string) string {
if len(def) == 0 {
def = append(def, "")
}
return r.Get(key, def[0]).(string)
return cast.ToString(r.Get(key, def[0]))
}

//Has Check an item exists in the cache.
Expand Down
63 changes: 37 additions & 26 deletions http/middleware/throttle.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,67 +3,62 @@ package middleware
import (
"crypto/md5"
"encoding/hex"
"net/http"
"fmt"
"time"

"github.com/spf13/cast"

contractshttp "github.com/goravel/framework/contracts/http"
"github.com/goravel/framework/contracts/http"
"github.com/goravel/framework/facades"
httplimit "github.com/goravel/framework/http/limit"
supporttime "github.com/goravel/framework/support/time"
)

func Throttle(name string) contractshttp.Middleware {
return func(ctx contractshttp.Context) {
func Throttle(name string) http.Middleware {
return func(ctx http.Context) {
if limiter := facades.RateLimiter.Limiter(name); limiter != nil {
if limits := limiter(ctx); len(limits) > 0 {
for _, limit := range limits {
if instance, ok := limit.(*httplimit.Limit); ok {
// if no key is set, use the path and ip address as the default key
if len(instance.Key) == 0 {
hash := md5.Sum([]byte(ctx.Request().Path()))
instance.Key = facades.Config.GetString("cache.prefix") + ":throttle:" + name + ":" + hex.EncodeToString(hash[:]) + ":" + ctx.Request().Ip()
} else {
hash := md5.Sum([]byte(instance.Key))
instance.Key = facades.Config.GetString("cache.prefix") + ":throttle:" + name + ":" + hex.EncodeToString(hash[:])
}
key, timer := key(ctx, name, instance)
currentTimes := 1

// check if the timer exists in the cache
if facades.Cache.Has(instance.Key + ":timer") {
// if the timer exists, check if the number of attempts is greater than the max attempts
value := facades.Cache.GetInt(instance.Key, 0)
if facades.Cache.Has(timer) {
value := facades.Cache.GetInt(key, 0)
if value >= instance.MaxAttempts {
// add the retry headers to the response
ctx.Response().Header("X-RateLimit-Reset", cast.ToString(cast.ToInt(facades.Cache.Get(instance.Key+":timer", 0))+instance.DecayMinutes*60))
ctx.Response().Header("Retry-After", cast.ToString(cast.ToInt(facades.Cache.Get(instance.Key+":timer", 0))+instance.DecayMinutes*60-int(supporttime.Now().Unix())))
expireSecond := facades.Cache.GetInt(timer, 0) + instance.DecayMinutes*60
ctx.Response().Header("X-RateLimit-Reset", cast.ToString(expireSecond))
ctx.Response().Header("Retry-After", cast.ToString(expireSecond-int(supporttime.Now().Unix())))
if instance.ResponseCallback != nil {
instance.ResponseCallback(ctx)
return
} else {
ctx.Request().AbortWithStatus(http.StatusTooManyRequests)
return
}
} else {
// TODO: change Put to Increment in the future
err := facades.Cache.Put(instance.Key, value+1, time.Duration(instance.DecayMinutes)*time.Minute)
if err != nil {
var err error
if currentTimes, err = facades.Cache.Increment(key); err != nil {
panic(err)
}
}
} else {
// if the timer does not exist, create it and set the number of attempts to 1
err := facades.Cache.Put(instance.Key+":timer", supporttime.Now().Unix(), time.Duration(instance.DecayMinutes)*time.Minute)
expireMinute := time.Duration(instance.DecayMinutes) * time.Minute

err := facades.Cache.Put(timer, supporttime.Now().Unix(), expireMinute)
if err != nil {
panic(err)
}
err = facades.Cache.Put(instance.Key, 1, time.Duration(instance.DecayMinutes)*time.Minute)

err = facades.Cache.Put(key, currentTimes, expireMinute)
if err != nil {
panic(err)
}
}

// add the headers for the passed request
ctx.Response().Header("X-RateLimit-Limit", cast.ToString(instance.MaxAttempts))
ctx.Response().Header("X-RateLimit-Remaining", cast.ToString(instance.MaxAttempts-facades.Cache.GetInt(instance.Key, 0)))
ctx.Response().Header("X-RateLimit-Remaining", cast.ToString(instance.MaxAttempts-currentTimes))
}
}
}
Expand All @@ -72,3 +67,19 @@ func Throttle(name string) contractshttp.Middleware {
ctx.Request().Next()
}
}

func key(ctx http.Context, limiter string, limit *httplimit.Limit) (string, string) {
// if no key is set, use the path and ip address as the default key
var key, timer string
prefix := facades.Config.GetString("cache.prefix")
if len(limit.Key) == 0 {
hash := md5.Sum([]byte(ctx.Request().Path()))
key = fmt.Sprintf("%s:throttle:%s:%s:%s", prefix, limiter, hex.EncodeToString(hash[:]), ctx.Request().Ip())
} else {
hash := md5.Sum([]byte(limit.Key))
key = fmt.Sprintf("%s:throttle:%s:%s", prefix, limiter, hex.EncodeToString(hash[:]))
}
timer = key + ":timer"

return key, timer
}
76 changes: 33 additions & 43 deletions route/gin_group_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -354,49 +354,6 @@ func TestGinGroup(t *testing.T) {
expectCode: http.StatusOK,
expectBody: "{\"global\":\"goravel\"}",
},
{
name: "Throttle Middleware Passed",
setup: func(req *http.Request) {
mockConfig.On("GetString", "cache.stores.memory.driver").Return("memory").Once()
mockConfig.On("GetString", "cache.prefix").Return("throttle").Twice()

facades.Cache = frameworkcache.NewApplication("memory")
facades.RateLimiter = frameworkhttp.NewRateLimiter()
facades.RateLimiter.For("test", func(ctx httpcontract.Context) httpcontract.Limit {
return limit.PerMinute(1)
})

gin.GlobalMiddleware(middleware.Throttle("test"))
gin.Any("/throttle/{id}", func(ctx httpcontract.Context) {
ctx.Response().Success().Json(httpcontract.Json{
"id": ctx.Request().Input("id"),
})
})
req.Header.Set("Origin", "http://127.0.0.1")
req.Header.Set("Access-Control-Request-Method", "GET")
},
method: "GET",
url: "/throttle/1",
expectCode: http.StatusOK,
},
{
name: "Throttle Middleware Blocked",
setup: func(req *http.Request) {
mockConfig.On("GetString", "cache.prefix").Return("throttle").Twice()

gin.GlobalMiddleware(middleware.Throttle("test"))
gin.Any("/throttle/{id}", func(ctx httpcontract.Context) {
ctx.Response().Success().Json(httpcontract.Json{
"id": ctx.Request().Input("id"),
})
})
req.Header.Set("Origin", "http://127.0.0.1")
req.Header.Set("Access-Control-Request-Method", "GET")
},
method: "GET",
url: "/throttle/1",
expectCode: http.StatusTooManyRequests,
},
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
Expand All @@ -416,6 +373,39 @@ func TestGinGroup(t *testing.T) {
}
}

func TestThrottle(t *testing.T) {
mockConfig := mock.Config()
mockConfig.On("GetBool", "app.debug").Return(true).Once()
mockConfig.On("GetString", "cache.stores.memory.driver").Return("memory").Once()
mockConfig.On("GetString", "cache.prefix").Return("throttle").Times(3)

facades.Cache = frameworkcache.NewApplication("memory")
facades.RateLimiter = frameworkhttp.NewRateLimiter()
facades.RateLimiter.For("test", func(ctx httpcontract.Context) httpcontract.Limit {
return limit.PerMinute(1)
})

gin := NewGin()
gin.GlobalMiddleware(middleware.Throttle("test"))
gin.Get("/throttle/{id}", func(ctx httpcontract.Context) {
ctx.Response().Success().Json(httpcontract.Json{
"id": ctx.Request().Input("id"),
})
})

w := httptest.NewRecorder()
req, _ := http.NewRequest("GET", "/throttle/1", nil)
gin.ServeHTTP(w, req)
assert.Equal(t, http.StatusOK, w.Code)

w1 := httptest.NewRecorder()
req1, _ := http.NewRequest("GET", "/throttle/1", nil)
gin.ServeHTTP(w1, req1)
assert.Equal(t, http.StatusTooManyRequests, w1.Code)

mockConfig.AssertExpectations(t)
}

func abortMiddleware() httpcontract.Middleware {
return func(ctx httpcontract.Context) {
ctx.Request().AbortWithStatus(http.StatusNonAuthoritativeInfo)
Expand Down

0 comments on commit ff82643

Please sign in to comment.