Skip to content

Commit 3e7c7ab

Browse files
committed
combine hitless flags in pool.conn
1 parent 0930948 commit 3e7c7ab

File tree

2 files changed

+132
-47
lines changed

2 files changed

+132
-47
lines changed

hitless/handoff_worker.go

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ package hitless
22

33
import (
44
"context"
5+
"errors"
56
"net"
67
"sync"
78
"sync/atomic"
@@ -235,12 +236,18 @@ func (hwm *handoffWorkerManager) processHandoffRequest(request HandoffRequest) {
235236
// queueHandoff queues a handoff request for processing
236237
// if err is returned, connection will be removed from pool
237238
func (hwm *handoffWorkerManager) queueHandoff(conn *pool.Conn) error {
238-
// Create handoff request
239+
// Get handoff info atomically to prevent race conditions
240+
shouldHandoff, endpoint, seqID := conn.GetHandoffInfo()
241+
if !shouldHandoff {
242+
return errors.New("connection is not marked for handoff")
243+
}
244+
245+
// Create handoff request with atomically retrieved data
239246
request := HandoffRequest{
240247
Conn: conn,
241248
ConnID: conn.GetID(),
242-
Endpoint: conn.GetHandoffEndpoint(),
243-
SeqID: conn.GetMovingSeqID(),
249+
Endpoint: endpoint,
250+
SeqID: seqID,
244251
Pool: hwm.poolHook.pool, // Include pool for connection removal on failure
245252
}
246253

internal/pool/conn.go

Lines changed: 122 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,15 @@ var noDeadline = time.Time{}
1818
// Global atomic counter for connection IDs
1919
var connIDCounter uint64
2020

21+
// HandoffState represents the atomic state for connection handoffs
22+
// This struct is stored atomically to prevent race conditions between
23+
// checking handoff status and reading handoff parameters
24+
type HandoffState struct {
25+
ShouldHandoff bool // Whether connection should be handed off
26+
Endpoint string // New endpoint for handoff
27+
SeqID int64 // Sequence ID from MOVING notification
28+
}
29+
2130
// atomicNetConn is a wrapper to ensure consistent typing in atomic.Value
2231
type atomicNetConn struct {
2332
conn net.Conn
@@ -69,11 +78,11 @@ type Conn struct {
6978

7079
// Handoff state - using atomic operations for lock-free access
7180
usableAtomic atomic.Bool // Connection usability state
72-
shouldHandoffAtomic atomic.Bool // Whether connection should be handed off
73-
movingSeqIDAtomic atomic.Int64 // Sequence ID from MOVING notification
7481
handoffRetriesAtomic atomic.Uint32 // Retry counter for handoff attempts
75-
// newEndpointAtomic needs special handling as it's a string
76-
newEndpointAtomic atomic.Value // stores string
82+
83+
// Atomic handoff state to prevent race conditions
84+
// Stores *HandoffState to ensure atomic updates of all handoff-related fields
85+
handoffStateAtomic atomic.Value // stores *HandoffState
7786

7887
onClose func() error
7988
}
@@ -104,12 +113,17 @@ func NewConnWithBufferSize(netConn net.Conn, readBufSize, writeBufSize int) *Con
104113
// Store netConn atomically for lock-free access using wrapper
105114
cn.netConnAtomic.Store(&atomicNetConn{conn: netConn})
106115

107-
// Initialize atomic handoff state
116+
// Initialize atomic state
108117
cn.usableAtomic.Store(false) // false initially, set to true after initialization
109-
cn.shouldHandoffAtomic.Store(false) // false initially
110-
cn.movingSeqIDAtomic.Store(0) // 0 initially
111118
cn.handoffRetriesAtomic.Store(0) // 0 initially
112-
cn.newEndpointAtomic.Store("") // empty string initially
119+
120+
// Initialize handoff state atomically
121+
initialHandoffState := &HandoffState{
122+
ShouldHandoff: false,
123+
Endpoint: "",
124+
SeqID: 0,
125+
}
126+
cn.handoffStateAtomic.Store(initialHandoffState)
113127

114128
cn.wr = proto.NewWriter(cn.bw)
115129
cn.SetUsedAt(time.Now())
@@ -154,37 +168,38 @@ func (cn *Conn) setUsable(usable bool) {
154168
cn.usableAtomic.Store(usable)
155169
}
156170

157-
// shouldHandoff returns true if connection needs handoff (lock-free).
158-
func (cn *Conn) shouldHandoff() bool {
159-
return cn.shouldHandoffAtomic.Load()
171+
// getHandoffState returns the current handoff state atomically (lock-free).
172+
func (cn *Conn) getHandoffState() *HandoffState {
173+
state := cn.handoffStateAtomic.Load()
174+
if state == nil {
175+
// Return default state if not initialized
176+
return &HandoffState{
177+
ShouldHandoff: false,
178+
Endpoint: "",
179+
SeqID: 0,
180+
}
181+
}
182+
return state.(*HandoffState)
160183
}
161184

162-
// setShouldHandoff sets the handoff flag atomically (lock-free).
163-
func (cn *Conn) setShouldHandoff(should bool) {
164-
cn.shouldHandoffAtomic.Store(should)
185+
// setHandoffState sets the handoff state atomically (lock-free).
186+
func (cn *Conn) setHandoffState(state *HandoffState) {
187+
cn.handoffStateAtomic.Store(state)
165188
}
166189

167-
// getMovingSeqID returns the sequence ID atomically (lock-free).
168-
func (cn *Conn) getMovingSeqID() int64 {
169-
return cn.movingSeqIDAtomic.Load()
190+
// shouldHandoff returns true if connection needs handoff (lock-free).
191+
func (cn *Conn) shouldHandoff() bool {
192+
return cn.getHandoffState().ShouldHandoff
170193
}
171194

172-
// setMovingSeqID sets the sequence ID atomically (lock-free).
173-
func (cn *Conn) setMovingSeqID(seqID int64) {
174-
cn.movingSeqIDAtomic.Store(seqID)
195+
// getMovingSeqID returns the sequence ID atomically (lock-free).
196+
func (cn *Conn) getMovingSeqID() int64 {
197+
return cn.getHandoffState().SeqID
175198
}
176199

177200
// getNewEndpoint returns the new endpoint atomically (lock-free).
178201
func (cn *Conn) getNewEndpoint() string {
179-
if endpoint := cn.newEndpointAtomic.Load(); endpoint != nil {
180-
return endpoint.(string)
181-
}
182-
return ""
183-
}
184-
185-
// setNewEndpoint sets the new endpoint atomically (lock-free).
186-
func (cn *Conn) setNewEndpoint(endpoint string) {
187-
cn.newEndpointAtomic.Store(endpoint)
202+
return cn.getHandoffState().Endpoint
188203
}
189204

190205
// setHandoffRetries sets the retry count atomically (lock-free).
@@ -396,24 +411,74 @@ func (cn *Conn) SetNetConnAndInitConn(ctx context.Context, netConn net.Conn) err
396411

397412
// MarkForHandoff marks the connection for handoff due to MOVING notification (lock-free).
398413
// Returns an error if the connection is already marked for handoff.
414+
// This method uses atomic compare-and-swap to ensure all handoff state is updated atomically.
399415
func (cn *Conn) MarkForHandoff(newEndpoint string, seqID int64) error {
400-
// Use single atomic CAS operation for state transition
401-
if !cn.shouldHandoffAtomic.CompareAndSwap(false, true) {
402-
return errors.New("connection is already marked for handoff")
416+
const maxRetries = 50
417+
const baseDelay = time.Microsecond
418+
419+
for attempt := 0; attempt < maxRetries; attempt++ {
420+
currentState := cn.getHandoffState()
421+
422+
// Check if already marked for handoff
423+
if currentState.ShouldHandoff {
424+
return errors.New("connection is already marked for handoff")
425+
}
426+
427+
// Create new state with handoff enabled
428+
newState := &HandoffState{
429+
ShouldHandoff: true,
430+
Endpoint: newEndpoint,
431+
SeqID: seqID,
432+
}
433+
434+
// Atomic compare-and-swap to update entire state
435+
if cn.handoffStateAtomic.CompareAndSwap(currentState, newState) {
436+
return nil
437+
}
438+
439+
// If CAS failed, add exponential backoff to reduce contention
440+
if attempt < maxRetries-1 {
441+
delay := baseDelay * time.Duration(1<<uint(attempt%10)) // Cap exponential growth
442+
time.Sleep(delay)
443+
}
403444
}
404445

405-
cn.setNewEndpoint(newEndpoint)
406-
cn.setMovingSeqID(seqID)
407-
return nil
446+
return fmt.Errorf("failed to mark connection for handoff after %d attempts due to high contention", maxRetries)
408447
}
409448

410449
func (cn *Conn) MarkQueuedForHandoff() error {
411-
// Use single atomic CAS operation for state transition
412-
if !cn.shouldHandoffAtomic.CompareAndSwap(true, false) {
413-
return errors.New("connection was not marked for handoff")
450+
const maxRetries = 50
451+
const baseDelay = time.Microsecond
452+
453+
for attempt := 0; attempt < maxRetries; attempt++ {
454+
currentState := cn.getHandoffState()
455+
456+
// Check if marked for handoff
457+
if !currentState.ShouldHandoff {
458+
return errors.New("connection was not marked for handoff")
459+
}
460+
461+
// Create new state with handoff disabled (queued)
462+
newState := &HandoffState{
463+
ShouldHandoff: false,
464+
Endpoint: currentState.Endpoint, // Preserve endpoint for handoff processing
465+
SeqID: currentState.SeqID, // Preserve seqID for handoff processing
466+
}
467+
468+
// Atomic compare-and-swap to update state
469+
if cn.handoffStateAtomic.CompareAndSwap(currentState, newState) {
470+
cn.setUsable(false)
471+
return nil
472+
}
473+
474+
// If CAS failed, add exponential backoff to reduce contention
475+
if attempt < maxRetries-1 {
476+
delay := baseDelay * time.Duration(1<<uint(attempt%10)) // Cap exponential growth
477+
time.Sleep(delay)
478+
}
414479
}
415-
cn.setUsable(false)
416-
return nil
480+
481+
return fmt.Errorf("failed to mark connection as queued for handoff after %d attempts due to high contention", maxRetries)
417482
}
418483

419484
// ShouldHandoff returns true if the connection needs to be handed off (lock-free).
@@ -431,17 +496,30 @@ func (cn *Conn) GetMovingSeqID() int64 {
431496
return cn.getMovingSeqID()
432497
}
433498

499+
// GetHandoffInfo returns all handoff information atomically (lock-free).
500+
// This method prevents race conditions by returning all handoff state in a single atomic operation.
501+
// Returns (shouldHandoff, endpoint, seqID).
502+
func (cn *Conn) GetHandoffInfo() (bool, string, int64) {
503+
state := cn.getHandoffState()
504+
return state.ShouldHandoff, state.Endpoint, state.SeqID
505+
}
506+
434507
// GetID returns the unique identifier for this connection.
435508
func (cn *Conn) GetID() uint64 {
436509
return cn.id
437510
}
438511

439512
// ClearHandoffState clears the handoff state after successful handoff (lock-free).
440513
func (cn *Conn) ClearHandoffState() {
441-
// clear handoff state
442-
cn.setShouldHandoff(false)
443-
cn.setNewEndpoint("")
444-
cn.setMovingSeqID(0)
514+
// Create clean state
515+
cleanState := &HandoffState{
516+
ShouldHandoff: false,
517+
Endpoint: "",
518+
SeqID: 0,
519+
}
520+
521+
// Atomically set clean state
522+
cn.setHandoffState(cleanState)
445523
cn.setHandoffRetries(0)
446524
cn.setUsable(true) // Connection is safe to use again after handoff completes
447525
}

0 commit comments

Comments
 (0)