Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 44 additions & 0 deletions contrib/github.com/go-redis/redis.v5/redis.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
package redis

import (
"time"

"github.com/ntindall/speedbump/internal"
redis "gopkg.in/redis.v5"
)

// Wrapper is a wrapper around *redis.Client that implements the
// internal.RedisClient interface.
type Wrapper struct {
*redis.Client
}

var _ internal.RedisClient = &Wrapper{}

func (w *Wrapper) Exists(key string) (exists bool, err error) {
return w.Client.Exists(key).Result()
}

func (w *Wrapper) Get(key string) (value string, err error) {
return w.Client.Get(key).Result()
}

func (w *Wrapper) IncrAndExpire(key string, duration time.Duration) error {
return w.Client.Watch(func(rx *redis.Tx) error {
_, err := rx.Pipelined(func(pipe *redis.Pipeline) error {
if err := pipe.Incr(key).Err(); err != nil {
return err
}

return pipe.Expire(key, duration).Err()
})

return err
})
}

// NewRedisClient constructs a speedbump.RedisClient from a "gopkg.in/redis.v5"
// redis.Client.
func NewRedisClient(redisClient *redis.Client) internal.RedisClient {
return &Wrapper{redisClient}
}
41 changes: 41 additions & 0 deletions contrib/github.com/gomodule/redigo/redis/redis.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
package redis

import (
"time"

redis "github.com/gomodule/redigo/redis"
"github.com/ntindall/speedbump/internal"
)

type redisWrapper struct {
conn redis.Conn
}

var _ internal.RedisClient = &redisWrapper{}

func (w *redisWrapper) Exists(key string) (exists bool, err error) {
return redis.Bool(w.conn.Do("EXISTS", key))
}

func (w *redisWrapper) Get(key string) (value string, err error) {
return redis.String(w.conn.Do("GET", key))
}

func (w *redisWrapper) IncrAndExpire(key string, duration time.Duration) error {
if err := w.conn.Send("MULTI"); err != nil {
return err
}
if err := w.conn.Send("INCR", key); err != nil {
return err
}
if err := w.conn.Send("EXPIRE", key, duration/time.Second); err != nil {
return err
}
_, err := w.conn.Do("EXEC")
return err
}

// NewRedisClient constructs a internal.RedisClient from a redigo connection.
func NewRedisClient(redisConn redis.Conn) internal.RedisClient {
return &redisWrapper{conn: redisConn}
}
14 changes: 14 additions & 0 deletions internal/internal.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
package internal

import (
"time"
)

// RedisClient is an abstraction over speedbump connection to redis.
// It is exported from internal so that it can only be instructed from
// within the package.
type RedisClient interface {
Get(key string) (value string, err error)
Exists(key string) (exists bool, err error)
IncrAndExpire(key string, duration time.Duration) error
}
32 changes: 13 additions & 19 deletions speedbump.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,17 @@ import (
"strconv"
"time"

"gopkg.in/redis.v5"
"github.com/ntindall/speedbump/internal"
)

var (
redisNil string = "redis: nil"
)

// RateLimiter is a Redis-backed rate limiter.
type RateLimiter struct {
// redisClient is the client that will be used to talk to the Redis server.
redisClient *redis.Client
redisClient internal.RedisClient
// hasher is used to generate keys for each counter and to set their
// expiration time.
hasher RateHasher
Expand All @@ -36,7 +40,7 @@ type RateHasher interface {

// NewLimiter creates a new instance of a rate limiter.
func NewLimiter(
client *redis.Client,
client internal.RedisClient,
hasher RateHasher,
max int64,
) *RateLimiter {
Expand All @@ -51,18 +55,18 @@ func NewLimiter(
// during the current period.
func (r *RateLimiter) Has(id string) (bool, error) {
hash := r.hasher.Hash(id)
return r.redisClient.Exists(hash).Result()
return r.redisClient.Exists(hash)
}

// Attempted returns the number of attempted requests for an id in the current
// period. Attempted does not count attempts that exceed the max requests in an
// interval and only returns the max count after this is reached.
func (r *RateLimiter) Attempted(id string) (int64, error) {
hash := r.hasher.Hash(id)
val, err := r.redisClient.Get(hash).Result()
val, err := r.redisClient.Get(hash)

if err != nil {
if err == redis.Nil {
if err.Error() == redisNil {
// Key does not exist. See: http://redis.io/commands/GET
return 0, nil
}
Expand Down Expand Up @@ -104,9 +108,9 @@ func (r *RateLimiter) Attempt(id string) (bool, error) {
// exist.
exists := true

val, err := r.redisClient.Get(hash).Result()
val, err := r.redisClient.Get(hash)
if err != nil {
if err == redis.Nil {
if err.Error() == redisNil {
// Key does not exist. See: http://redis.io/commands/GET
exists = false
} else {
Expand All @@ -132,17 +136,7 @@ func (r *RateLimiter) Attempt(id string) (bool, error) {
//
// See: http://redis.io/commands/INCR
// See: http://redis.io/commands/INCR#pattern-rate-limiter-1
err = r.redisClient.Watch(func(rx *redis.Tx) error {
_, err := rx.Pipelined(func(pipe *redis.Pipeline) error {
if err := pipe.Incr(hash).Err(); err != nil {
return err
}

return pipe.Expire(hash, r.hasher.Duration()).Err()
})

return err
})
err = r.redisClient.IncrAndExpire(hash, r.hasher.Duration())

if err != nil {
return false, err
Expand Down
63 changes: 26 additions & 37 deletions speedbump_test.go
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
package speedbump
package speedbump_test

import (
"fmt"
Expand All @@ -7,55 +7,44 @@ import (
"time"

"github.com/facebookgo/clock"
"github.com/ntindall/speedbump"
contribredis "github.com/ntindall/speedbump/contrib/github.com/go-redis/redis.v5"
"github.com/ntindall/speedbump/internal"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"

"gopkg.in/redis.v5"
redis "gopkg.in/redis.v5"
)

func createClient() *redis.Client {
func createClient() internal.RedisClient {
addr := "localhost:6379"
if os.Getenv("WERCKER_REDIS_HOST") != "" {
return redis.NewClient(&redis.Options{
Addr: os.Getenv("WERCKER_REDIS_HOST") + ":6379",
Password: "",
DB: 0,
})
addr = os.Getenv("WERCKER_REDIS_HOST") + ":6379"
}

return redis.NewClient(&redis.Options{
Addr: "localhost:6379",
Password: "",
DB: 0,
})
}

func teardown(t *testing.T, client *redis.Client) {
// Flush Redis.
require.NoError(t, client.FlushAll().Err())
return contribredis.NewRedisClient(
redis.NewClient(&redis.Options{
Addr: addr,
Password: "",
DB: 0,
}),
)
}

func TestNewLimiter(t *testing.T) {
client := createClient()
hasher := PerSecondHasher{}
max := int64(10)
actual := NewLimiter(client, hasher, max)
func teardown(t *testing.T, client internal.RedisClient) {

assert.Exactly(t, RateLimiter{
redisClient: client,
hasher: hasher,
max: max,
}, *actual)
// Flush Redis.
require.NoError(t, client.(*contribredis.Wrapper).FlushAll().Err())
}

func ExampleNewLimiter() {
// Create a Redis client.
client := createClient()

// Create a new hasher.
hasher := PerSecondHasher{}
hasher := speedbump.PerSecondHasher{}

// Create a new limiter that will only allow 10 requests per second.
limiter := NewLimiter(client, hasher, 10)
limiter := speedbump.NewLimiter(client, hasher, 10)

fmt.Println(limiter.Attempt("127.0.0.1"))
// Output: true <nil>
Expand All @@ -66,7 +55,7 @@ func TestHas(t *testing.T) {
client := createClient()
defer teardown(t, client)
// Create limiter of 5 requests/min.
limiter := NewLimiter(client, PerMinuteHasher{}, 5)
limiter := speedbump.NewLimiter(client, speedbump.PerMinuteHasher{}, 5)
// Choose an arbitrary id.
testID := "test_id"

Expand Down Expand Up @@ -106,11 +95,11 @@ func TestAttempt(t *testing.T) {
defer teardown(t, client)
// Create PerMinuteHasher with mock clock.
mock := clock.NewMock()
hasher := PerMinuteHasher{
hasher := speedbump.PerMinuteHasher{
Clock: mock,
}
// Create limiter of 5 requests/min.
limiter := NewLimiter(client, hasher, 5)
limiter := speedbump.NewLimiter(client, hasher, 5)
// Choose an arbitrary id.
testID := "test_id"
// Ensure no key exists before first request for testID.
Expand Down Expand Up @@ -217,7 +206,7 @@ func TestAttempt(t *testing.T) {
assert.True(t, ok, "Attempts returned false after waiting for interval")
}

func makeNAttempts(t *testing.T, limiter *RateLimiter, id string, n int64) {
func makeNAttempts(t *testing.T, limiter *speedbump.RateLimiter, id string, n int64) {
var i int64
for i = 0; i < n; i++ {
_, err := limiter.Attempt(id)
Expand All @@ -231,12 +220,12 @@ func TestAttemptedLeft(t *testing.T) {
defer teardown(t, client)
// Create PerMinuteHasher with mock clock.
mock := clock.NewMock()
hasher := PerMinuteHasher{
hasher := speedbump.PerMinuteHasher{
Clock: mock,
}
max := int64(5)
// Create limiter of 5 requests/min.
limiter := NewLimiter(client, hasher, max)
limiter := speedbump.NewLimiter(client, hasher, max)
// Choose an arbitrary id.
testID := "test_id"

Expand Down