@@ -18,6 +18,15 @@ var noDeadline = time.Time{}
1818// Global atomic counter for connection IDs
1919var connIDCounter uint64
2020
// HandoffState bundles every value that describes a pending connection
// handoff. A Conn stores a pointer to one HandoffState in an atomic.Value and
// replaces the whole pointer on each change, so the flag, the endpoint and the
// sequence ID are always read as one consistent set. A stored value is meant
// to be treated as read-only.
type HandoffState struct {
	// ShouldHandoff reports whether the connection should be handed off.
	ShouldHandoff bool
	// Endpoint is the new endpoint the connection should be handed off to.
	Endpoint string
	// SeqID is the sequence ID taken from the MOVING notification.
	SeqID int64
}
29+
2130// atomicNetConn is a wrapper to ensure consistent typing in atomic.Value
2231type atomicNetConn struct {
2332 conn net.Conn
@@ -69,11 +78,11 @@ type Conn struct {
6978
7079 // Handoff state - using atomic operations for lock-free access
7180 usableAtomic atomic.Bool // Connection usability state
72- shouldHandoffAtomic atomic.Bool // Whether connection should be handed off
73- movingSeqIDAtomic atomic.Int64 // Sequence ID from MOVING notification
7481 handoffRetriesAtomic atomic.Uint32 // Retry counter for handoff attempts
75- // newEndpointAtomic needs special handling as it's a string
76- newEndpointAtomic atomic.Value // stores string
82+
83+ // Atomic handoff state to prevent race conditions
84+ // Stores *HandoffState to ensure atomic updates of all handoff-related fields
85+ handoffStateAtomic atomic.Value // stores *HandoffState
7786
7887 onClose func () error
7988}
@@ -104,12 +113,17 @@ func NewConnWithBufferSize(netConn net.Conn, readBufSize, writeBufSize int) *Con
104113 // Store netConn atomically for lock-free access using wrapper
105114 cn .netConnAtomic .Store (& atomicNetConn {conn : netConn })
106115
107- // Initialize atomic handoff state
116+ // Initialize atomic state
108117 cn .usableAtomic .Store (false ) // false initially, set to true after initialization
109- cn .shouldHandoffAtomic .Store (false ) // false initially
110- cn .movingSeqIDAtomic .Store (0 ) // 0 initially
111118 cn .handoffRetriesAtomic .Store (0 ) // 0 initially
112- cn .newEndpointAtomic .Store ("" ) // empty string initially
119+
120+ // Initialize handoff state atomically
121+ initialHandoffState := & HandoffState {
122+ ShouldHandoff : false ,
123+ Endpoint : "" ,
124+ SeqID : 0 ,
125+ }
126+ cn .handoffStateAtomic .Store (initialHandoffState )
113127
114128 cn .wr = proto .NewWriter (cn .bw )
115129 cn .SetUsedAt (time .Now ())
@@ -154,37 +168,38 @@ func (cn *Conn) setUsable(usable bool) {
154168 cn .usableAtomic .Store (usable )
155169}
156170
157- // shouldHandoff returns true if connection needs handoff (lock-free).
158- func (cn * Conn ) shouldHandoff () bool {
159- return cn .shouldHandoffAtomic .Load ()
171+ // getHandoffState returns the current handoff state atomically (lock-free).
172+ func (cn * Conn ) getHandoffState () * HandoffState {
173+ state := cn .handoffStateAtomic .Load ()
174+ if state == nil {
175+ // Return default state if not initialized
176+ return & HandoffState {
177+ ShouldHandoff : false ,
178+ Endpoint : "" ,
179+ SeqID : 0 ,
180+ }
181+ }
182+ return state .(* HandoffState )
160183}
161184
162- // setShouldHandoff sets the handoff flag atomically (lock-free).
163- func (cn * Conn ) setShouldHandoff ( should bool ) {
164- cn .shouldHandoffAtomic .Store (should )
185+ // setHandoffState sets the handoff state atomically (lock-free).
186+ func (cn * Conn ) setHandoffState ( state * HandoffState ) {
187+ cn .handoffStateAtomic .Store (state )
165188}
166189
167- // getMovingSeqID returns the sequence ID atomically (lock-free).
168- func (cn * Conn ) getMovingSeqID () int64 {
169- return cn .movingSeqIDAtomic . Load ()
190+ // shouldHandoff returns true if connection needs handoff (lock-free).
191+ func (cn * Conn ) shouldHandoff () bool {
192+ return cn .getHandoffState (). ShouldHandoff
170193}
171194
172- // setMovingSeqID sets the sequence ID atomically (lock-free).
173- func (cn * Conn ) setMovingSeqID ( seqID int64 ) {
174- cn .movingSeqIDAtomic . Store ( seqID )
195+ // getMovingSeqID returns the sequence ID atomically (lock-free).
196+ func (cn * Conn ) getMovingSeqID () int64 {
197+ return cn .getHandoffState (). SeqID
175198}
176199
177200// getNewEndpoint returns the new endpoint atomically (lock-free).
178201func (cn * Conn ) getNewEndpoint () string {
179- if endpoint := cn .newEndpointAtomic .Load (); endpoint != nil {
180- return endpoint .(string )
181- }
182- return ""
183- }
184-
185- // setNewEndpoint sets the new endpoint atomically (lock-free).
186- func (cn * Conn ) setNewEndpoint (endpoint string ) {
187- cn .newEndpointAtomic .Store (endpoint )
202+ return cn .getHandoffState ().Endpoint
188203}
189204
190205// setHandoffRetries sets the retry count atomically (lock-free).
@@ -396,24 +411,74 @@ func (cn *Conn) SetNetConnAndInitConn(ctx context.Context, netConn net.Conn) err
396411
397412// MarkForHandoff marks the connection for handoff due to MOVING notification (lock-free).
398413// Returns an error if the connection is already marked for handoff.
414+ // This method uses atomic compare-and-swap to ensure all handoff state is updated atomically.
399415func (cn * Conn ) MarkForHandoff (newEndpoint string , seqID int64 ) error {
400- // Use single atomic CAS operation for state transition
401- if ! cn .shouldHandoffAtomic .CompareAndSwap (false , true ) {
402- return errors .New ("connection is already marked for handoff" )
416+ const maxRetries = 50
417+ const baseDelay = time .Microsecond
418+
419+ for attempt := 0 ; attempt < maxRetries ; attempt ++ {
420+ currentState := cn .getHandoffState ()
421+
422+ // Check if already marked for handoff
423+ if currentState .ShouldHandoff {
424+ return errors .New ("connection is already marked for handoff" )
425+ }
426+
427+ // Create new state with handoff enabled
428+ newState := & HandoffState {
429+ ShouldHandoff : true ,
430+ Endpoint : newEndpoint ,
431+ SeqID : seqID ,
432+ }
433+
434+ // Atomic compare-and-swap to update entire state
435+ if cn .handoffStateAtomic .CompareAndSwap (currentState , newState ) {
436+ return nil
437+ }
438+
439+ // If CAS failed, add exponential backoff to reduce contention
440+ if attempt < maxRetries - 1 {
441+ delay := baseDelay * time .Duration (1 << uint (attempt % 10 )) // Cap exponential growth
442+ time .Sleep (delay )
443+ }
403444 }
404445
405- cn .setNewEndpoint (newEndpoint )
406- cn .setMovingSeqID (seqID )
407- return nil
446+ return fmt .Errorf ("failed to mark connection for handoff after %d attempts due to high contention" , maxRetries )
408447}
409448
410449func (cn * Conn ) MarkQueuedForHandoff () error {
411- // Use single atomic CAS operation for state transition
412- if ! cn .shouldHandoffAtomic .CompareAndSwap (true , false ) {
413- return errors .New ("connection was not marked for handoff" )
450+ const maxRetries = 50
451+ const baseDelay = time .Microsecond
452+
453+ for attempt := 0 ; attempt < maxRetries ; attempt ++ {
454+ currentState := cn .getHandoffState ()
455+
456+ // Check if marked for handoff
457+ if ! currentState .ShouldHandoff {
458+ return errors .New ("connection was not marked for handoff" )
459+ }
460+
461+ // Create new state with handoff disabled (queued)
462+ newState := & HandoffState {
463+ ShouldHandoff : false ,
464+ Endpoint : currentState .Endpoint , // Preserve endpoint for handoff processing
465+ SeqID : currentState .SeqID , // Preserve seqID for handoff processing
466+ }
467+
468+ // Atomic compare-and-swap to update state
469+ if cn .handoffStateAtomic .CompareAndSwap (currentState , newState ) {
470+ cn .setUsable (false )
471+ return nil
472+ }
473+
474+ // If CAS failed, add exponential backoff to reduce contention
475+ if attempt < maxRetries - 1 {
476+ delay := baseDelay * time .Duration (1 << uint (attempt % 10 )) // Cap exponential growth
477+ time .Sleep (delay )
478+ }
414479 }
415- cn . setUsable ( false )
416- return nil
480+
481+ return fmt . Errorf ( "failed to mark connection as queued for handoff after %d attempts due to high contention" , maxRetries )
417482}
418483
419484// ShouldHandoff returns true if the connection needs to be handed off (lock-free).
@@ -431,17 +496,30 @@ func (cn *Conn) GetMovingSeqID() int64 {
431496 return cn .getMovingSeqID ()
432497}
433498
499+ // GetHandoffInfo returns all handoff information atomically (lock-free).
500+ // This method prevents race conditions by returning all handoff state in a single atomic operation.
501+ // Returns (shouldHandoff, endpoint, seqID).
502+ func (cn * Conn ) GetHandoffInfo () (bool , string , int64 ) {
503+ state := cn .getHandoffState ()
504+ return state .ShouldHandoff , state .Endpoint , state .SeqID
505+ }
506+
434507// GetID returns the unique identifier for this connection.
435508func (cn * Conn ) GetID () uint64 {
436509 return cn .id
437510}
438511
439512// ClearHandoffState clears the handoff state after successful handoff (lock-free).
440513func (cn * Conn ) ClearHandoffState () {
441- // clear handoff state
442- cn .setShouldHandoff (false )
443- cn .setNewEndpoint ("" )
444- cn .setMovingSeqID (0 )
514+ // Create clean state
515+ cleanState := & HandoffState {
516+ ShouldHandoff : false ,
517+ Endpoint : "" ,
518+ SeqID : 0 ,
519+ }
520+
521+ // Atomically set clean state
522+ cn .setHandoffState (cleanState )
445523 cn .setHandoffRetries (0 )
446524 cn .setUsable (true ) // Connection is safe to use again after handoff completes
447525}
0 commit comments