Skip to content

Commit

Permalink
Support retriable func condition (#687)
Browse files Browse the repository at this point in the history
* Support retriable func condition

* Refactor WithCodes option
  • Loading branch information
tamayika authored Jan 29, 2024
1 parent f6f8eae commit 220740b
Show file tree
Hide file tree
Showing 3 changed files with 54 additions and 12 deletions.
34 changes: 31 additions & 3 deletions interceptors/retry/options.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (

"google.golang.org/grpc"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
)

var (
Expand All @@ -22,11 +23,11 @@ var (
max: 0, // disabled
perCallTimeout: 0, // disabled
includeHeader: true,
codes: DefaultRetriableCodes,
backoffFunc: BackoffLinearWithJitter(50*time.Millisecond /*jitter*/, 0.10),
onRetryCallback: OnRetryCallback(func(ctx context.Context, attempt uint, err error) {
logTrace(ctx, "grpc_retry attempt: %d, backoff for %v", attempt, err)
}),
retriableFunc: newRetriableFuncForCodes(DefaultRetriableCodes),
}
)

Expand All @@ -41,6 +42,9 @@ type BackoffFunc func(ctx context.Context, attempt uint) time.Duration
// OnRetryCallback is the type of function called when a retry occurs.
type OnRetryCallback func(ctx context.Context, attempt uint, err error)

// RetriableFunc denotes a family of functions that control which error should be retried.
type RetriableFunc func(err error) bool

// Disable disables the retry behaviour on this call, or this interceptor.
//
// Its semantically the same to `WithMax`
Expand Down Expand Up @@ -78,7 +82,7 @@ func WithOnRetryCallback(fn OnRetryCallback) CallOption {
// You cannot automatically retry on Cancelled and Deadline, please use `WithPerRetryTimeout` for these.
func WithCodes(retryCodes ...codes.Code) CallOption {
return CallOption{applyFunc: func(o *options) {
o.codes = retryCodes
o.retriableFunc = newRetriableFuncForCodes(retryCodes)
}}
}

Expand All @@ -100,13 +104,20 @@ func WithPerRetryTimeout(timeout time.Duration) CallOption {
}}
}

// WithRetriable sets which error should be retried.
func WithRetriable(retriableFunc RetriableFunc) CallOption {
return CallOption{applyFunc: func(o *options) {
o.retriableFunc = retriableFunc
}}
}

type options struct {
max uint
perCallTimeout time.Duration
includeHeader bool
codes []codes.Code
backoffFunc BackoffFunc
onRetryCallback OnRetryCallback
retriableFunc RetriableFunc
}

// CallOption is a grpc.CallOption that is local to grpc_retry.
Expand Down Expand Up @@ -137,3 +148,20 @@ func filterCallOptions(callOptions []grpc.CallOption) (grpcOptions []grpc.CallOp
}
return grpcOptions, retryOptions
}

// newRetriableFuncForCodes returns retriable function for specific Codes.
func newRetriableFuncForCodes(codes []codes.Code) func(err error) bool {
return func(err error) bool {
errCode := status.Code(err)
if isContextError(err) {
// context errors are not retriable based on user settings.
return false
}
for _, code := range codes {
if code == errCode {
return true
}
}
return false
}
}
11 changes: 2 additions & 9 deletions interceptors/retry/retry.go
Original file line number Diff line number Diff line change
Expand Up @@ -267,15 +267,8 @@ func waitRetryBackoff(attempt uint, parentCtx context.Context, callOpts *options
}

func isRetriable(err error, callOpts *options) bool {
errCode := status.Code(err)
if isContextError(err) {
// context errors are not retriable based on user settings.
return false
}
for _, code := range callOpts.codes {
if code == errCode {
return true
}
if callOpts.retriableFunc != nil {
return callOpts.retriableFunc(err)
}
return false
}
Expand Down
21 changes: 21 additions & 0 deletions interceptors/retry/retry_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ package retry
import (
"context"
"io"
"strings"
"sync"
"testing"
"time"
Expand Down Expand Up @@ -178,6 +179,16 @@ func (s *RetrySuite) TestUnary_OverrideFromDialOpts() {
require.EqualValues(s.T(), 5, s.srv.requestCount(), "five requests should have been made")
}

func (s *RetrySuite) TestUnary_OverrideFromDialOpts2() {
s.srv.resetFailingConfiguration(5, codes.ResourceExhausted, noSleep) // default is 3 and retriable_errors
out, err := s.Client.Ping(s.SimpleCtx(), testpb.GoodPing, WithRetriable(func(err error) bool {
return strings.Contains(err.Error(), "maybeFailRequest")
}), WithMax(5))
require.NoError(s.T(), err, "the fifth invocation should succeed")
require.NotNil(s.T(), out, "Pong must be not nil")
require.EqualValues(s.T(), 5, s.srv.requestCount(), "five requests should have been made")
}

func (s *RetrySuite) TestUnary_OnRetryCallbackCalled() {
retryCallbackCount := 0

Expand Down Expand Up @@ -209,6 +220,16 @@ func (s *RetrySuite) TestServerStream_OverrideFromContext() {
require.EqualValues(s.T(), 5, s.srv.requestCount(), "three requests should have been made")
}

func (s *RetrySuite) TestServerStream_OverrideFromContext2() {
s.srv.resetFailingConfiguration(5, codes.ResourceExhausted, noSleep) // default is 3 and retriable_errors
stream, err := s.Client.PingList(s.SimpleCtx(), testpb.GoodPingList, WithRetriable(func(err error) bool {
return strings.Contains(err.Error(), "maybeFailRequest")
}), WithMax(5))
require.NoError(s.T(), err, "establishing the connection must always succeed")
s.assertPingListWasCorrect(stream)
require.EqualValues(s.T(), 5, s.srv.requestCount(), "three requests should have been made")
}

func (s *RetrySuite) TestServerStream_OnRetryCallbackCalled() {
retryCallbackCount := 0

Expand Down

0 comments on commit 220740b

Please sign in to comment.