Skip to content

Commit 1b2b891

Browse files
authored
fix: Handle pool get ctx cancellation promptly (gomodule#470)
Handle the cancellation of a context promptly when waiting for a vacant connection from the pool.
1 parent 941d323 commit 1b2b891

File tree

2 files changed

+83
-37
lines changed

2 files changed

+83
-37
lines changed

redis/pool.go

Lines changed: 49 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -199,27 +199,10 @@ func (p *Pool) Get() Conn {
199199
// If the function completes without error, then the application must close the
200200
// returned connection.
201201
func (p *Pool) GetContext(ctx context.Context) (Conn, error) {
202-
// Handle limit for p.Wait == true.
203-
var waited time.Duration
204-
if p.Wait && p.MaxActive > 0 {
205-
p.lazyInit()
206-
207-
// wait indicates if we believe it will block so its not 100% accurate
208-
// however for stats it should be good enough.
209-
wait := len(p.ch) == 0
210-
var start time.Time
211-
if wait {
212-
start = time.Now()
213-
}
214-
select {
215-
case <-p.ch:
216-
case <-ctx.Done():
217-
err := ctx.Err()
218-
return errorConn{err}, err
219-
}
220-
if wait {
221-
waited = time.Since(start)
222-
}
202+
// Wait until there is a vacant connection in the pool.
203+
waited, err := p.waitVacantConn(ctx)
204+
if err != nil {
205+
return nil, err
223206
}
224207

225208
p.mu.Lock()
@@ -376,6 +359,51 @@ func (p *Pool) lazyInit() {
376359
p.mu.Unlock()
377360
}
378361

362+
// waitVacantConn waits for a vacant connection in pool if waiting
363+
// is enabled and pool size is limited, otherwise returns instantly.
364+
// If ctx expires before that, an error is returned.
365+
//
366+
// If there were no vacant connection in the pool right away it returns the time spent waiting
367+
// for that connection to appear in the pool.
368+
func (p *Pool) waitVacantConn(ctx context.Context) (waited time.Duration, err error) {
369+
if !p.Wait || p.MaxActive <= 0 {
370+
// No wait or no connection limit.
371+
return 0, nil
372+
}
373+
374+
p.lazyInit()
375+
376+
// wait indicates if we believe it will block so its not 100% accurate
377+
// however for stats it should be good enough.
378+
wait := len(p.ch) == 0
379+
var start time.Time
380+
if wait {
381+
start = time.Now()
382+
}
383+
384+
if ctx == nil {
385+
<-p.ch
386+
} else {
387+
select {
388+
case <-p.ch:
389+
// Additionally check that context hasn't expired while we were waiting,
390+
// because `select` picks a random `case` if several of them are "ready".
391+
select {
392+
case <-ctx.Done():
393+
return 0, ctx.Err()
394+
default:
395+
}
396+
case <-ctx.Done():
397+
return 0, ctx.Err()
398+
}
399+
}
400+
401+
if wait {
402+
return time.Since(start), nil
403+
}
404+
return 0, nil
405+
}
406+
379407
func (p *Pool) dial(ctx context.Context) (Conn, error) {
380408
if p.DialContext != nil {
381409
return p.DialContext(ctx)

redis/pool_test.go

Lines changed: 34 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -899,20 +899,38 @@ func TestWaitPoolGetAfterClose(t *testing.T) {
899899
}
900900

901901
func TestWaitPoolGetCanceledContext(t *testing.T) {
902-
d := poolDialer{t: t}
903-
p := &redis.Pool{
904-
MaxIdle: 1,
905-
MaxActive: 1,
906-
Dial: d.dial,
907-
Wait: true,
908-
}
909-
defer p.Close()
910-
ctx, f := context.WithCancel(context.Background())
911-
f()
912-
c := p.Get()
913-
defer c.Close()
914-
_, err := p.GetContext(ctx)
915-
if err != context.Canceled {
916-
t.Fatalf("got error %v, want %v", err, context.Canceled)
917-
}
902+
t.Run("without vacant connection in the pool", func(t *testing.T) {
903+
d := poolDialer{t: t}
904+
p := &redis.Pool{
905+
MaxIdle: 1,
906+
MaxActive: 1,
907+
Dial: d.dial,
908+
Wait: true,
909+
}
910+
defer p.Close()
911+
ctx, cancel := context.WithCancel(context.Background())
912+
cancel()
913+
c := p.Get()
914+
defer c.Close()
915+
_, err := p.GetContext(ctx)
916+
if err != context.Canceled {
917+
t.Fatalf("got error %v, want %v", err, context.Canceled)
918+
}
919+
})
920+
t.Run("with vacant connection in the pool", func(t *testing.T) {
921+
d := poolDialer{t: t}
922+
p := &redis.Pool{
923+
MaxIdle: 1,
924+
MaxActive: 1,
925+
Dial: d.dial,
926+
Wait: true,
927+
}
928+
defer p.Close()
929+
ctx, cancel := context.WithCancel(context.Background())
930+
cancel()
931+
_, err := p.GetContext(ctx)
932+
if err != context.Canceled {
933+
t.Fatalf("got error %v, want %v", err, context.Canceled)
934+
}
935+
})
918936
}

0 commit comments

Comments
 (0)