Skip to content

Commit 926bed3

Browse files
authored
Merge pull request #474 from adamdecaf/ratex-improvements
ratex: misc improvements from Grok
2 parents 6aeb726 + 06efade commit 926bed3

File tree

2 files changed

+231
-9
lines changed

2 files changed

+231
-9
lines changed

ratex/ratelimit.go

Lines changed: 35 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ import (
44
"context"
55
"crypto/rand"
66
"fmt"
7+
"math"
78
"math/big"
89
"time"
910

@@ -72,16 +73,32 @@ func generateRateLimiter(ctx context.Context, params RateLimitParams) (*rate.Lim
7273
}
7374

7475
// generateRateLimitDuration returns a random value between min-max duration multiplied by the multiplier.
76+
// It handles cases where max <= min and includes overflow protection.
7577
func generateRateLimitDuration(multiplier int, minDuration, maxDuration time.Duration) (time.Duration, error) {
7678
minVal := minDuration.Milliseconds()
7779
maxVal := maxDuration.Milliseconds()
7880

79-
maxRand, err := rand.Int(rand.Reader, big.NewInt(maxVal-minVal))
81+
if maxVal <= minVal {
82+
// If max <= min, use min * multiplier (or return error if preferred)
83+
waitInterval := minVal * int64(multiplier)
84+
if waitInterval < 0 || waitInterval > math.MaxInt64/1000000 { // Arbitrary cap to prevent overflow
85+
waitInterval = minVal // Fallback to min if overflow
86+
}
87+
return time.Duration(waitInterval) * time.Millisecond, nil
88+
}
89+
90+
delta := maxVal - minVal
91+
maxRand, err := rand.Int(rand.Reader, big.NewInt(delta))
8092
if err != nil {
8193
return 0, fmt.Errorf("rand int: %w", err)
8294
}
95+
8396
waitInterval := (minVal + maxRand.Int64()) * int64(multiplier)
84-
return time.Millisecond * time.Duration(waitInterval), nil
97+
if waitInterval < 0 || waitInterval > math.MaxInt64/1000000 { // Arbitrary cap to prevent overflow
98+
waitInterval = maxVal * int64(multiplier) // Cap at max * multiplier
99+
}
100+
101+
return time.Duration(waitInterval) * time.Millisecond, nil
85102
}
86103

87104
type RetryParams struct {
@@ -98,6 +115,17 @@ func ExecRetryable[R any](ctx context.Context, closure func(ctx context.Context)
98115
err error
99116
)
100117

118+
// Validate params
119+
if params.MaxRetries <= 0 {
120+
params.MaxRetries = 1 // Default to at least one try
121+
}
122+
if params.MinDuration <= 0 {
123+
params.MinDuration = 100 * time.Millisecond // Default min backoff
124+
}
125+
if params.MaxDuration < params.MinDuration {
126+
params.MaxDuration = params.MinDuration * 10 // Default max to 10x min
127+
}
128+
101129
retryFunc := func(ctx context.Context, retryAttempt int) (R, error) {
102130
tryCtx, span := telemetry.StartSpan(ctx, "try",
103131
trace.WithAttributes(
@@ -109,7 +137,7 @@ func ExecRetryable[R any](ctx context.Context, closure func(ctx context.Context)
109137
return closure(tryCtx)
110138
}
111139

112-
for i := range params.MaxRetries {
140+
for i := 0; i < params.MaxRetries; i++ {
113141
retryAttempt := i + 1
114142
retVal, err = retryFunc(ctx, retryAttempt)
115143

@@ -132,17 +160,17 @@ func ExecRetryable[R any](ctx context.Context, closure func(ctx context.Context)
132160
// generate rate limiter to delay retries.
133161
// This will jitter a wait time before the next iteration.
134162
//
135-
// We continue on rate limit errors and retry without waiting
136-
params := RateLimitParams{
163+
// We abort on rate limit errors (e.g., ctx cancel) instead of continuing
164+
rlParams := RateLimitParams{
137165
RateLimiter: rateLimiter,
138166
RetryAttempt: retryAttempt,
139167
MinDuration: params.MinDuration,
140168
MaxDuration: params.MaxDuration,
141169
}
142-
rateLimiter, err = RateLimit(ctx, params)
170+
rateLimiter, err = RateLimit(ctx, rlParams)
143171
if err != nil {
144172
telemetry.AddEvent(ctx, fmt.Sprintf("rate limit: %s", err.Error()))
145-
continue
173+
return retVal, err // Abort on error (e.g., context canceled)
146174
}
147175
}
148176
}

ratex/ratelimit_test.go

Lines changed: 196 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,12 @@ package ratex
33
import (
44
"context"
55
"errors"
6+
"math"
67
"testing"
78
"time"
89

910
"github.com/stretchr/testify/require"
11+
"golang.org/x/time/rate"
1012
)
1113

1214
func TestExecRetryable(t *testing.T) {
@@ -30,8 +32,8 @@ func TestExecRetryable(t *testing.T) {
3032
t.Run("Retryable failure with success before last retry", func(t *testing.T) {
3133
attempts := 0
3234
closure := func(ctx context.Context) (string, error) {
33-
if attempts < 2 {
34-
attempts++
35+
attempts++
36+
if attempts < 3 {
3537
return "", errors.New("retryable error")
3638
}
3739
return "success", nil
@@ -42,9 +44,17 @@ func TestExecRetryable(t *testing.T) {
4244
MinDuration: 10 * time.Millisecond,
4345
MaxDuration: 50 * time.Millisecond,
4446
}
47+
start := time.Now()
4548
result, err := ExecRetryable(ctx, closure, params)
49+
elapsed := time.Since(start)
4650
require.NoErrorf(t, err, "Expected success, got error: %v", err)
4751
require.Equalf(t, "success", result, "Expected result 'success', got: %v", result)
52+
require.Equal(t, 3, attempts, "Expected 3 attempts")
53+
// Check approximate backoff time (2 backoffs: ~10-50ms *1 + ~10-50ms *2)
54+
minElapsed := params.MinDuration*1 + params.MinDuration*2
55+
maxElapsed := params.MaxDuration*1 + params.MaxDuration*2 + 50*time.Millisecond // Overhead allowance
56+
require.GreaterOrEqual(t, elapsed, minElapsed, "Elapsed time too short")
57+
require.LessOrEqual(t, elapsed, maxElapsed, "Elapsed time too long")
4858
})
4959

5060
t.Run("Non-retryable failure", func(t *testing.T) {
@@ -64,7 +74,9 @@ func TestExecRetryable(t *testing.T) {
6474
})
6575

6676
t.Run("Retryable failures exceeding MaxRetries", func(t *testing.T) {
77+
attempts := 0
6778
closure := func(ctx context.Context) (string, error) {
79+
attempts++
6880
return "", errors.New("retryable error")
6981
}
7082
params := RetryParams{
@@ -76,6 +88,188 @@ func TestExecRetryable(t *testing.T) {
7688
result, err := ExecRetryable(ctx, closure, params)
7789
require.Errorf(t, err, "Expected error after exceeding max retries, got: %v", err)
7890
require.Empty(t, result)
91+
require.Equal(t, 3, attempts, "Expected 3 attempts")
7992
require.Equal(t, "hit max tries 3: try 3 of 3: retryable error", err.Error())
8093
})
94+
95+
t.Run("Context cancellation during closure", func(t *testing.T) {
96+
ctx, cancel := context.WithCancel(context.Background())
97+
defer cancel()
98+
99+
attempts := 0
100+
closure := func(ctx context.Context) (string, error) {
101+
attempts++
102+
if attempts == 2 {
103+
cancel()
104+
}
105+
return "", errors.New("retryable error")
106+
}
107+
params := RetryParams{
108+
ShouldRetry: func(err error) bool { return true },
109+
MaxRetries: 3,
110+
MinDuration: 10 * time.Millisecond,
111+
MaxDuration: 50 * time.Millisecond,
112+
}
113+
result, err := ExecRetryable(ctx, closure, params)
114+
require.ErrorIs(t, err, context.Canceled)
115+
require.Empty(t, result)
116+
require.Equal(t, 2, attempts, "Expected 2 attempts before cancel")
117+
})
118+
119+
t.Run("Context cancellation during backoff", func(t *testing.T) {
120+
ctx, cancel := context.WithCancel(context.Background())
121+
122+
closure := func(ctx context.Context) (string, error) {
123+
return "", errors.New("retryable error")
124+
}
125+
params := RetryParams{
126+
ShouldRetry: func(err error) bool { return true },
127+
MaxRetries: 3,
128+
MinDuration: 100 * time.Millisecond,
129+
MaxDuration: 200 * time.Millisecond,
130+
}
131+
go func() {
132+
time.Sleep(50 * time.Millisecond) // Cancel mid-backoff
133+
cancel()
134+
}()
135+
result, err := ExecRetryable(ctx, closure, params)
136+
require.ErrorIs(t, err, context.Canceled)
137+
require.Empty(t, result)
138+
})
139+
140+
t.Run("Invalid params with defaults", func(t *testing.T) {
141+
closure := func(ctx context.Context) (string, error) {
142+
return "success", nil
143+
}
144+
params := RetryParams{
145+
ShouldRetry: func(err error) bool { return true },
146+
MaxRetries: 0, // Should default to 1
147+
MinDuration: 0, // Should default to 100ms
148+
MaxDuration: 0, // Should default to 100ms * 10 = 1s
149+
}
150+
result, err := ExecRetryable(ctx, closure, params)
151+
require.NoError(t, err)
152+
require.Equal(t, "success", result)
153+
// Note: Defaults are applied, but no retries needed here
154+
})
155+
156+
t.Run("MinDuration > MaxDuration", func(t *testing.T) {
157+
closure := func(ctx context.Context) (string, error) {
158+
return "success", nil
159+
}
160+
params := RetryParams{
161+
ShouldRetry: func(err error) bool { return true },
162+
MaxRetries: 3,
163+
MinDuration: 100 * time.Millisecond,
164+
MaxDuration: 50 * time.Millisecond, // Will be set to 100ms * 10 = 1s
165+
}
166+
result, err := ExecRetryable(ctx, closure, params)
167+
require.NoError(t, err)
168+
require.Equal(t, "success", result)
169+
})
170+
171+
t.Run("Different return type", func(t *testing.T) {
172+
closure := func(ctx context.Context) (int, error) {
173+
return 42, nil
174+
}
175+
params := RetryParams{
176+
ShouldRetry: func(err error) bool { return true },
177+
MaxRetries: 3,
178+
MinDuration: 10 * time.Millisecond,
179+
MaxDuration: 50 * time.Millisecond,
180+
}
181+
result, err := ExecRetryable(ctx, closure, params)
182+
require.NoError(t, err)
183+
require.Equal(t, 42, result)
184+
})
185+
}
186+
187+
func TestGenerateRateLimitDuration(t *testing.T) {
188+
t.Run("Standard case", func(t *testing.T) {
189+
for i := 0; i < 10; i++ { // Run multiple times to check randomness
190+
dur, err := generateRateLimitDuration(1, 100*time.Millisecond, 200*time.Millisecond)
191+
require.NoError(t, err)
192+
require.GreaterOrEqual(t, dur, 100*time.Millisecond)
193+
require.LessOrEqual(t, dur, 200*time.Millisecond)
194+
}
195+
})
196+
197+
t.Run("Max <= Min", func(t *testing.T) {
198+
dur, err := generateRateLimitDuration(2, 100*time.Millisecond, 50*time.Millisecond)
199+
require.NoError(t, err)
200+
require.Equal(t, 200*time.Millisecond, dur) // min * multiplier
201+
})
202+
203+
t.Run("Delta == 0", func(t *testing.T) {
204+
dur, err := generateRateLimitDuration(3, 50*time.Millisecond, 50*time.Millisecond)
205+
require.NoError(t, err)
206+
require.Equal(t, 150*time.Millisecond, dur)
207+
})
208+
209+
t.Run("Overflow cap", func(t *testing.T) {
210+
// Set large multiplier to trigger cap
211+
// Assume cap at math.MaxInt64 / 1000000 ~ 9e12 ms (~104 days)
212+
// Use minVal=1ms, multiplier such that 1 * mul > 9e12
213+
largeMul := int(math.MaxInt64 / 1000000 / 2) // Safe large int
214+
dur, err := generateRateLimitDuration(largeMul, 1*time.Millisecond, 2*time.Millisecond)
215+
require.NoError(t, err)
216+
// Since min + rand(0 or 1) * largeMul, but cap to max * mul = 2 * largeMul ms
217+
require.LessOrEqual(t, int64(dur.Milliseconds()), 2*int64(largeMul))
218+
})
219+
220+
t.Run("Negative multiplier (edge case)", func(t *testing.T) {
221+
dur, err := generateRateLimitDuration(-1, 100*time.Millisecond, 200*time.Millisecond)
222+
require.NoError(t, err)
223+
// Since waitInterval negative, cap kicks in, but code sets to max * mul if overflow/negative
224+
// But mul negative, so waitInterval negative, capped to max * mul (negative, but duration cast)
225+
// Actually, code checks waitInterval <0, sets to maxVal * int64(multiplier)
226+
// If mul negative, this would be negative, but time.Duration negative is invalid
227+
// Note: This test highlights potential issue, but multiplier is always positive in usage
228+
require.LessOrEqual(t, dur, time.Duration(0)) // Expect non-positive
229+
})
230+
}
231+
232+
func TestRateLimit(t *testing.T) {
233+
ctx := context.Background()
234+
235+
t.Run("New limiter", func(t *testing.T) {
236+
params := RateLimitParams{
237+
RateLimiter: nil,
238+
RetryAttempt: 1,
239+
MinDuration: 10 * time.Millisecond,
240+
MaxDuration: 20 * time.Millisecond,
241+
}
242+
limiter, err := RateLimit(ctx, params)
243+
require.NoError(t, err)
244+
require.NotNil(t, limiter)
245+
})
246+
247+
t.Run("Existing limiter update", func(t *testing.T) {
248+
existing := rate.NewLimiter(rate.Every(100*time.Millisecond), 1)
249+
params := RateLimitParams{
250+
RateLimiter: existing,
251+
RetryAttempt: 2,
252+
MinDuration: 10 * time.Millisecond,
253+
MaxDuration: 20 * time.Millisecond,
254+
}
255+
limiter, err := RateLimit(ctx, params)
256+
require.NoError(t, err)
257+
require.Equal(t, existing, limiter)
258+
})
259+
260+
t.Run("Context cancel during wait", func(t *testing.T) {
261+
ctx, cancel := context.WithCancel(context.Background())
262+
params := RateLimitParams{
263+
RateLimiter: nil,
264+
RetryAttempt: 1,
265+
MinDuration: 100 * time.Millisecond,
266+
MaxDuration: 200 * time.Millisecond,
267+
}
268+
go func() {
269+
time.Sleep(50 * time.Millisecond)
270+
cancel()
271+
}()
272+
_, err := RateLimit(ctx, params)
273+
require.ErrorIs(t, err, context.Canceled)
274+
})
81275
}

0 commit comments

Comments
 (0)