Skip to content

Commit

Permalink
Add unary and streaming client-side rate-limit interceptors (#520)
Browse files Browse the repository at this point in the history
  • Loading branch information
Rahul Khairwar committed Mar 22, 2023
1 parent dd1540e commit 7801504
Show file tree
Hide file tree
Showing 5 changed files with 169 additions and 15 deletions.
1 change: 0 additions & 1 deletion go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ go 1.14

require (
github.com/golang/protobuf v1.5.2
github.com/opentracing/opentracing-go v1.2.0
github.com/pkg/errors v0.9.1
github.com/stretchr/testify v1.7.0
golang.org/x/net v0.0.0-20210423184538-5f58ad60dda6
Expand Down
3 changes: 0 additions & 3 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -113,16 +113,13 @@ github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORN
github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ=
github.com/kr/text v0.1.0 h1:45sCR5RtlFHMR4UwH9sdQ5TC8v0qDQCHnXt+kaKSTVE=
github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI=
github.com/opentracing/opentracing-go v1.2.0 h1:uEJPy/1a5RIPAJ0Ov+OIO8OxWu77jEv+1B0VhjKrZUs=
github.com/opentracing/opentracing-go v1.2.0/go.mod h1:GxEUsuufX4nBwe+T+Wl9TAgYrxe9dPLANfrWvHYVTgc=
github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4=
github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/prometheus/client_model v0.0.0-20190812154241-14fe0d1b01d4/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA=
github.com/rogpeppe/go-internal v1.3.0/go.mod h1:M8bDsm7K2OlrFYOpmOWEs/qY81heoFRclV5y23lUDJ4=
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=
github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4=
github.com/stretchr/testify v1.5.1/go.mod h1:5W2xD1RspED5o8YsWQXVCued0rvSQ+mT+I5cxcmMvtA=
github.com/stretchr/testify v1.7.0 h1:nwc3DEeHmmLAfoZucVR881uASk0Mfjw8xYJ99tb5CcY=
Expand Down
45 changes: 42 additions & 3 deletions interceptors/ratelimit/examples_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,17 +19,56 @@ func (*alwaysPassLimiter) Limit(_ context.Context) error {
return nil
}

// Simple example of server initialization code.
func Example() {
// Simple example of a unary server initialization code.
func ExampleUnaryServerInterceptor() {
// Create unary/stream rateLimiters, based on token bucket here.
// You can implement your own ratelimiter for the interface.
// You can implement your own rate-limiter for the interface.
limiter := &alwaysPassLimiter{}
_ = grpc.NewServer(
grpc.ChainUnaryInterceptor(
ratelimit.UnaryServerInterceptor(limiter),
),
)
}

// Simple example of a streaming server initialization code.
func ExampleStreamServerInterceptor() {
// Create unary/stream rateLimiters, based on token bucket here.
// You can implement your own rate-limiter for the interface.
limiter := &alwaysPassLimiter{}
_ = grpc.NewServer(
grpc.ChainStreamInterceptor(
ratelimit.StreamServerInterceptor(limiter),
),
)
}

// Simple example of a unary client initialization code.
func ExampleUnaryClientInterceptor() {
// Create stream rateLimiter, based on token bucket here.
// You can implement your own rate-limiter for the interface.
limiter := &alwaysPassLimiter{}
_, _ = grpc.DialContext(
context.Background(),
":8080",
grpc.WithInsecure(),
grpc.WithUnaryInterceptor(
ratelimit.UnaryClientInterceptor(limiter),
),
)
}

// Simple example of a streaming client initialization code.
func ExampleStreamClientInterceptor() {
// Create stream rateLimiter, based on token bucket here.
// You can implement your own rate-limiter for the interface.
limiter := &alwaysPassLimiter{}
_, _ = grpc.DialContext(
context.Background(),
":8080",
grpc.WithInsecure(),
grpc.WithChainStreamInterceptor(
ratelimit.StreamClientInterceptor(limiter),
),
)
}
27 changes: 27 additions & 0 deletions interceptors/ratelimit/ratelimit.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,3 +37,30 @@ func StreamServerInterceptor(limiter Limiter) grpc.StreamServerInterceptor {
return handler(srv, stream)
}
}

// UnaryClientInterceptor returns a new unary client interceptor that performs rate limiting on the request on the
// client side.
// This can be helpful for clients that want to limit the number of requests they send in a given time, potentially
// saving cost.
func UnaryClientInterceptor(limiter Limiter) grpc.UnaryClientInterceptor {
return func(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn,
invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error {
if err := limiter.Limit(ctx); err != nil {
return status.Errorf(codes.ResourceExhausted, "%s is rejected by grpc_ratelimit middleware, please retry later. %s", method, err)
}
return invoker(ctx, method, req, reply, cc, opts...)
}
}

// StreamClientInterceptor returns a new stream client interceptor that performs rate limiting on the request on the
// client side.
// This can be helpful for clients that want to limit the number of requests they send in a given time, potentially
// saving cost.
func StreamClientInterceptor(limiter Limiter) grpc.StreamClientInterceptor {
return func(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, method string, streamer grpc.Streamer, opts ...grpc.CallOption) (grpc.ClientStream, error) {
if err := limiter.Limit(ctx); err != nil {
return nil, status.Errorf(codes.ResourceExhausted, "%s is rejected by grpc_ratelimit middleware, please retry later. %s", method, err)
}
return streamer(ctx, desc, cc, method, opts...)
}
}
108 changes: 100 additions & 8 deletions interceptors/ratelimit/ratelimit_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,13 @@ import (
"github.com/pkg/errors"
"github.com/stretchr/testify/assert"
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
)

const errMsgFake = "fake error"

var ctxLimitKey = struct{}{}
var ctxKeyShouldLimit = "should_limit"

type mockGRPCServerStream struct {
grpc.ServerStream
Expand All @@ -29,13 +31,18 @@ func (m *mockGRPCServerStream) Context() context.Context {
type mockContextBasedLimiter struct{}

func (*mockContextBasedLimiter) Limit(ctx context.Context) error {
l, _ := ctx.Value(ctxLimitKey).(error)
return l
shouldLimit, _ := ctx.Value(ctxKeyShouldLimit).(bool)

if shouldLimit {
return errors.New("rate limit exceeded")
}

return nil
}

func TestUnaryServerInterceptor_RateLimitPass(t *testing.T) {
limiter := new(mockContextBasedLimiter)
ctx := context.WithValue(context.Background(), ctxLimitKey, false)
ctx := context.WithValue(context.Background(), ctxKeyShouldLimit, false)

interceptor := UnaryServerInterceptor(limiter)
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
Expand All @@ -51,7 +58,7 @@ func TestUnaryServerInterceptor_RateLimitPass(t *testing.T) {

func TestStreamServerInterceptor_RateLimitPass(t *testing.T) {
limiter := new(mockContextBasedLimiter)
ctx := context.WithValue(context.Background(), ctxLimitKey, false)
ctx := context.WithValue(context.Background(), ctxKeyShouldLimit, false)

interceptor := StreamServerInterceptor(limiter)
handler := func(srv interface{}, stream grpc.ServerStream) error {
Expand All @@ -66,31 +73,116 @@ func TestStreamServerInterceptor_RateLimitPass(t *testing.T) {

func TestUnaryServerInterceptor_RateLimitFail(t *testing.T) {
limiter := new(mockContextBasedLimiter)
ctx := context.WithValue(context.Background(), ctxLimitKey, true)
ctx := context.WithValue(context.Background(), ctxKeyShouldLimit, true)

interceptor := UnaryServerInterceptor(limiter)
called := false
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
called = true
return nil, errors.New(errMsgFake)
}
info := &grpc.UnaryServerInfo{
FullMethod: "FakeMethod",
}
resp, err := interceptor(ctx, nil, info, handler)
expErr := status.Errorf(
codes.ResourceExhausted,
"%s is rejected by grpc_ratelimit middleware, please retry later. %s",
info.FullMethod,
"rate limit exceeded",
)
assert.Nil(t, resp)
assert.EqualError(t, err, errMsgFake)
assert.EqualError(t, err, expErr.Error())
assert.False(t, called)
}

func TestStreamServerInterceptor_RateLimitFail(t *testing.T) {
limiter := new(mockContextBasedLimiter)
ctx := context.WithValue(context.Background(), ctxLimitKey, true)
ctx := context.WithValue(context.Background(), ctxKeyShouldLimit, true)

interceptor := StreamServerInterceptor(limiter)
called := false
handler := func(srv interface{}, stream grpc.ServerStream) error {
called = true
return errors.New(errMsgFake)
}
info := &grpc.StreamServerInfo{
FullMethod: "FakeMethod",
}
err := interceptor(nil, &mockGRPCServerStream{ctx: ctx}, info, handler)
expErr := status.Errorf(
codes.ResourceExhausted,
"%s is rejected by grpc_ratelimit middleware, please retry later. %s",
info.FullMethod,
"rate limit exceeded",
)

assert.EqualError(t, err, expErr.Error())
assert.False(t, called)
}

func TestUnaryClientInterceptor_RateLimitPass(t *testing.T) {
limiter := new(mockContextBasedLimiter)
ctx := context.WithValue(context.Background(), ctxKeyShouldLimit, false)

interceptor := UnaryClientInterceptor(limiter)
invoker := func(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn, opts ...grpc.CallOption) error {
return errors.New(errMsgFake)
}
err := interceptor(ctx, "FakeMethod", nil, nil, nil, invoker)
assert.EqualError(t, err, errMsgFake)
}

func TestStreamClientInterceptor_RateLimitPass(t *testing.T) {
limiter := new(mockContextBasedLimiter)
ctx := context.WithValue(context.Background(), ctxKeyShouldLimit, false)

interceptor := StreamClientInterceptor(limiter)
invoker := func(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, method string, opts ...grpc.CallOption) (grpc.ClientStream, error) {
return nil, errors.New(errMsgFake)
}
_, err := interceptor(ctx, nil, nil, "FakeMethod", invoker)
assert.EqualError(t, err, errMsgFake)
}

func TestUnaryClientInterceptor_RateLimitFail(t *testing.T) {
limiter := new(mockContextBasedLimiter)
ctx := context.WithValue(context.Background(), ctxKeyShouldLimit, true)

interceptor := UnaryClientInterceptor(limiter)
called := false
invoker := func(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn, opts ...grpc.CallOption) error {
called = true
return errors.New(errMsgFake)
}
err := interceptor(ctx, "FakeMethod", nil, nil, nil, invoker)
expErr := status.Errorf(
codes.ResourceExhausted,
"%s is rejected by grpc_ratelimit middleware, please retry later. %s",
"FakeMethod",
"rate limit exceeded",
)
assert.EqualError(t, err, expErr.Error())
assert.False(t, called)
}

func TestStreamClientInterceptor_RateLimitFail(t *testing.T) {
limiter := new(mockContextBasedLimiter)
ctx := context.WithValue(context.Background(), ctxKeyShouldLimit, true)

interceptor := StreamClientInterceptor(limiter)
called := false
invoker := func(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, method string, opts ...grpc.CallOption) (grpc.ClientStream, error) {
called = true
return nil, errors.New(errMsgFake)
}
_, err := interceptor(ctx, nil, nil, "FakeMethod", invoker)
expErr := status.Errorf(
codes.ResourceExhausted,
"%s is rejected by grpc_ratelimit middleware, please retry later. %s",
"FakeMethod",
"rate limit exceeded",
)
assert.EqualError(t, err, expErr.Error())
assert.False(t, called)
}

0 comments on commit 7801504

Please sign in to comment.