Skip to content

Commit 5fe0bfa

Browse files
committed
fix(pool): wip, pool reauth should not interfere with handoff
1 parent 3ad9f9c commit 5fe0bfa

File tree

6 files changed

+158
-22
lines changed

6 files changed

+158
-22
lines changed
Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,87 @@
1+
package auth
2+
3+
import (
4+
"time"
5+
6+
"github.com/redis/go-redis/v9/internal/pool"
7+
)
8+
9+
// ConnReAuthCredentialsListener is a struct that implements the CredentialsListener interface.
10+
// It is used to re-authenticate the credentials when they are updated.
11+
// It holds reference to the connection to re-authenticate and will pass it to the reAuth and onErr callbacks.
12+
// It contains:
13+
// - reAuth: a function that takes the new credentials and returns an error if any.
14+
// - onErr: a function that takes an error and handles it.
15+
// - conn: the connection to re-authenticate.
16+
type ConnReAuthCredentialsListener struct {
17+
reAuth func(conn *pool.Conn, credentials Credentials) error
18+
onErr func(conn *pool.Conn, err error)
19+
conn *pool.Conn
20+
}
21+
22+
// OnNext is called when the credentials are updated.
23+
// It calls the reAuth function with the new credentials.
24+
// If the reAuth function returns an error, it calls the onErr function with the error.
25+
func (c *ConnReAuthCredentialsListener) OnNext(credentials Credentials) {
26+
if c.conn.IsClosed() {
27+
return
28+
}
29+
30+
if c.reAuth == nil {
31+
return
32+
}
33+
34+
var err error
35+
timeout := time.After(1 * time.Second)
36+
// wait for the connection to be usable
37+
// this is important because the connection pool may be in the process of reconnecting the connection
38+
// and we don't want to interfere with that process
39+
// but we also don't want to block for too long, so incorporate a timeout
40+
for {
41+
// we were able to mark the connection as unusable
42+
if c.conn.Usable.CompareAndSwap(true, false) {
43+
break
44+
}
45+
46+
select {
47+
case <-timeout:
48+
err = pool.ErrConnUnusableTimeout
49+
break
50+
default:
51+
}
52+
}
53+
if err != nil {
54+
c.OnError(err)
55+
return
56+
}
57+
// we set the usable flag, so restore it back to usable after we're done
58+
defer c.conn.SetUsable(true)
59+
60+
err = c.reAuth(c.conn, credentials)
61+
if err != nil {
62+
c.OnError(err)
63+
}
64+
}
65+
66+
// OnError is called when an error occurs.
67+
// It can be called from both the credentials provider and the reAuth function.
68+
func (c *ConnReAuthCredentialsListener) OnError(err error) {
69+
if c.onErr == nil {
70+
return
71+
}
72+
73+
c.onErr(c.conn, err)
74+
}
75+
76+
// NewConnReAuthCredentialsListener creates a new ConnReAuthCredentialsListener.
77+
// Implements the auth.CredentialsListener interface.
78+
func NewConnReAuthCredentialsListener(conn *pool.Conn, reAuth func(conn *pool.Conn, credentials Credentials) error, onErr func(conn *pool.Conn, err error)) *ConnReAuthCredentialsListener {
79+
return &ConnReAuthCredentialsListener{
80+
conn: conn,
81+
reAuth: reAuth,
82+
onErr: onErr,
83+
}
84+
}
85+
86+
// Ensure ConnReAuthCredentialsListener implements the CredentialsListener interface.
87+
var _ CredentialsListener = (*ConnReAuthCredentialsListener)(nil)

auth/reauth_credentials_listener.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,4 +44,4 @@ func NewReAuthCredentialsListener(reAuth func(credentials Credentials) error, on
4444
}
4545

4646
// Ensure ReAuthCredentialsListener implements the CredentialsListener interface.
47-
var _ CredentialsListener = (*ReAuthCredentialsListener)(nil)
47+
var _ CredentialsListener = (*ReAuthCredentialsListener)(nil)

error.go

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -112,6 +112,8 @@ func isBadConn(err error, allowTimeout bool, addr string) bool {
112112
return false
113113
case context.Canceled, context.DeadlineExceeded:
114114
return true
115+
case pool.ErrConnUnusableTimeout:
116+
return true
115117
}
116118

117119
if isRedisError(err) {

internal/pool/conn.go

Lines changed: 16 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,9 @@ func generateConnID() uint64 {
4040
}
4141

4242
type Conn struct {
43+
// Connection identifier for unique tracking
44+
id uint64 // Unique numeric identifier for this connection
45+
4346
usedAt int64 // atomic
4447

4548
// Lock-free netConn access using atomic.Value
@@ -54,7 +57,9 @@ type Conn struct {
5457
// Only used for the brief period during SetNetConn and HasBufferedData/PeekReplyTypeSafe
5558
readerMu sync.RWMutex
5659

57-
Inited atomic.Bool
60+
Usable atomic.Bool
61+
Inited atomic.Bool
62+
5863
pooled bool
5964
pubsub bool
6065
closed atomic.Bool
@@ -75,18 +80,14 @@ type Conn struct {
7580
// Connection initialization function for reconnections
7681
initConnFunc func(context.Context, *Conn) error
7782

78-
// Connection identifier for unique tracking
79-
id uint64 // Unique numeric identifier for this connection
80-
8183
// Handoff state - using atomic operations for lock-free access
82-
usableAtomic atomic.Bool // Connection usability state
8384
handoffRetriesAtomic atomic.Uint32 // Retry counter for handoff attempts
8485

8586
// Atomic handoff state to prevent race conditions
8687
// Stores *HandoffState to ensure atomic updates of all handoff-related fields
8788
handoffStateAtomic atomic.Value // stores *HandoffState
8889

89-
onClose func() error
90+
onClose func() error
9091
}
9192

9293
func NewConn(netConn net.Conn) *Conn {
@@ -116,7 +117,7 @@ func NewConnWithBufferSize(netConn net.Conn, readBufSize, writeBufSize int) *Con
116117
cn.netConnAtomic.Store(&atomicNetConn{conn: netConn})
117118

118119
// Initialize atomic state
119-
cn.usableAtomic.Store(false) // false initially, set to true after initialization
120+
cn.Usable.Store(false) // false initially, set to true after initialization
120121
cn.handoffRetriesAtomic.Store(0) // 0 initially
121122

122123
// Initialize handoff state atomically
@@ -162,12 +163,12 @@ func (cn *Conn) setNetConn(netConn net.Conn) {
162163

163164
// isUsable returns true if the connection is safe to use (lock-free).
164165
func (cn *Conn) isUsable() bool {
165-
return cn.usableAtomic.Load()
166+
return cn.Usable.Load()
166167
}
167168

168169
// setUsable sets the usable flag atomically (lock-free).
169170
func (cn *Conn) setUsable(usable bool) {
170-
cn.usableAtomic.Store(usable)
171+
cn.Usable.Store(usable)
171172
}
172173

173174
// getHandoffState returns the current handoff state atomically (lock-free).
@@ -456,6 +457,12 @@ func (cn *Conn) MarkQueuedForHandoff() error {
456457
const baseDelay = time.Microsecond
457458

458459
for attempt := 0; attempt < maxRetries; attempt++ {
460+
// first we need to mark the connection as not usable
461+
// to prevent the pool from returning it to the caller
462+
if !cn.Usable.CompareAndSwap(true, false) {
463+
continue
464+
}
465+
459466
currentState := cn.getHandoffState()
460467

461468
// Check if marked for handoff
@@ -472,7 +479,6 @@ func (cn *Conn) MarkQueuedForHandoff() error {
472479

473480
// Atomic compare-and-swap to update state
474481
if cn.handoffStateAtomic.CompareAndSwap(currentState, newState) {
475-
cn.setUsable(false)
476482
return nil
477483
}
478484

internal/pool/pool.go

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,9 @@ var (
2424
// ErrPoolTimeout timed out waiting to get a connection from the connection pool.
2525
ErrPoolTimeout = errors.New("redis: connection pool timeout")
2626

27+
// ErrConnUnusableTimeout is returned when we timed out waiting for a connection to become usable so that it could be claimed (marked as unusable) for exclusive use.
28+
ErrConnUnusableTimeout = errors.New("redis: timed out trying to mark connection as unusable")
29+
2730
// popAttempts is the maximum number of attempts to find a usable connection
2831
// when popping from the idle connection pool. This handles cases where connections
2932
// are temporarily marked as unusable (e.g., during maintenanceNotifications upgrades or network issues).

redis.go

Lines changed: 49 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -224,6 +224,9 @@ type baseClient struct {
224224
// Maintenance notifications manager
225225
maintNotificationsManager *maintnotifications.Manager
226226
maintNotificationsManagerLock sync.RWMutex
227+
228+
credListeners map[uint64]auth.CredentialsListener
229+
credListenersLock sync.RWMutex
227230
}
228231

229232
func (c *baseClient) clone() *baseClient {
@@ -237,6 +240,7 @@ func (c *baseClient) clone() *baseClient {
237240
onClose: c.onClose,
238241
pushProcessor: c.pushProcessor,
239242
maintNotificationsManager: maintNotificationsManager,
243+
credListeners: c.credListeners,
240244
}
241245
return clone
242246
}
@@ -296,18 +300,43 @@ func (c *baseClient) _getConn(ctx context.Context) (*pool.Conn, error) {
296300
return cn, nil
297301
}
298302

299-
func (c *baseClient) newReAuthCredentialsListener(poolCn *pool.Conn) auth.CredentialsListener {
300-
return auth.NewReAuthCredentialsListener(
301-
c.reAuthConnection(poolCn),
302-
c.onAuthenticationErr(poolCn),
303+
// connReAuthCredentialsListener returns a credentials listener that can be used to re-authenticate the connection.
304+
// The credentials listener is stored in a map, so that it can be reused for multiple connections.
305+
// The credentials listener is removed from the map when the connection is closed.
306+
func (c *baseClient) connReAuthCredentialsListener(poolCn *pool.Conn) (auth.CredentialsListener, func()) {
307+
c.credListenersLock.RLock()
308+
credListener, ok := c.credListeners[poolCn.GetID()]
309+
c.credListenersLock.RUnlock()
310+
if ok {
311+
return credListener.(auth.CredentialsListener), func() {
312+
c.removeCredListener(poolCn)
313+
}
314+
}
315+
c.credListenersLock.Lock()
316+
defer c.credListenersLock.Unlock()
317+
newCredListener := auth.NewConnReAuthCredentialsListener(
318+
poolCn,
319+
c.reAuthConnection(),
320+
c.onAuthenticationErr(),
303321
)
322+
c.credListeners[poolCn.GetID()] = newCredListener
323+
return newCredListener, func() {
324+
c.removeCredListener(poolCn)
325+
}
326+
}
327+
328+
func (c *baseClient) removeCredListener(poolCn *pool.Conn) {
329+
c.credListenersLock.Lock()
330+
defer c.credListenersLock.Unlock()
331+
delete(c.credListeners, poolCn.GetID())
304332
}
305333

306-
func (c *baseClient) reAuthConnection(poolCn *pool.Conn) func(credentials auth.Credentials) error {
307-
return func(credentials auth.Credentials) error {
334+
func (c *baseClient) reAuthConnection() func(poolCn *pool.Conn, credentials auth.Credentials) error {
335+
return func(poolCn *pool.Conn, credentials auth.Credentials) error {
308336
var err error
309337
username, password := credentials.BasicAuth()
310338
ctx := context.Background()
339+
311340
connPool := pool.NewSingleConnPool(c.connPool, poolCn)
312341
// hooksMixin are intentionally empty here
313342
cn := newConn(c.opt, connPool, nil)
@@ -320,8 +349,8 @@ func (c *baseClient) reAuthConnection(poolCn *pool.Conn) func(credentials auth.C
320349
return err
321350
}
322351
}
323-
func (c *baseClient) onAuthenticationErr(poolCn *pool.Conn) func(err error) {
324-
return func(err error) {
352+
func (c *baseClient) onAuthenticationErr() func(poolCn *pool.Conn, err error) {
353+
return func(poolCn *pool.Conn, err error) {
325354
if err != nil {
326355
if isBadConn(err, false, c.opt.Addr) {
327356
// Close the connection to force a reconnection.
@@ -372,13 +401,20 @@ func (c *baseClient) initConn(ctx context.Context, cn *pool.Conn) error {
372401

373402
username, password := "", ""
374403
if c.opt.StreamingCredentialsProvider != nil {
404+
credListener, removeCredListener := c.connReAuthCredentialsListener(cn)
375405
credentials, unsubscribeFromCredentialsProvider, err := c.opt.StreamingCredentialsProvider.
376-
Subscribe(c.newReAuthCredentialsListener(cn))
406+
Subscribe(credListener)
377407
if err != nil {
378408
return fmt.Errorf("failed to subscribe to streaming credentials: %w", err)
379409
}
380-
c.onClose = c.wrappedOnClose(unsubscribeFromCredentialsProvider)
381-
cn.SetOnClose(unsubscribeFromCredentialsProvider)
410+
411+
unsubscribe := func() error {
412+
removeCredListener()
413+
return unsubscribeFromCredentialsProvider()
414+
}
415+
c.onClose = c.wrappedOnClose(unsubscribe)
416+
cn.SetOnClose(unsubscribe)
417+
382418
username, password = credentials.BasicAuth()
383419
} else if c.opt.CredentialsProviderContext != nil {
384420
username, password, err = c.opt.CredentialsProviderContext(ctx)
@@ -496,6 +532,8 @@ func (c *baseClient) initConn(ctx context.Context, cn *pool.Conn) error {
496532
}
497533
}
498534

535+
// mark the connection as usable and inited
536+
// once returned to the pool as idle, this connection can be used by other clients
499537
cn.SetUsable(true)
500538
cn.Inited.Store(true)
501539

0 commit comments

Comments
 (0)