Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
65 changes: 56 additions & 9 deletions rtx_timer.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import (
)

const (
// Recommended values per RFC 9260 section 16.
// RTO.Initial in msec.
rtoInitial float64 = 1.0 * 1000

Expand All @@ -32,10 +33,14 @@ const (
pathMaxRetrans uint = 5

noMaxRetrans uint = 0

// Clock granularity G (RFC 9260 sec 6.3.1 G1). Use a fine granularity (<100ms).
// Unit: msec.
clockGranularity float64 = 1.0
)

// rtoManager manages Rtx timeout values.
// This is an implementation of RFC 4960 sec 6.3.1.
// This is an implementation of RFC 9260 sec 6.3.1.
type rtoManager struct {
srtt float64
rttvar float64
Expand Down Expand Up @@ -74,6 +79,11 @@ func (m *rtoManager) setNewRTT(rtt float64) float64 {
} else {
// Subsequent rtt measurement
m.rttvar = (1-rtoBeta)*m.rttvar + rtoBeta*(math.Abs(m.srtt-rtt))

// RFC 9260 sec 6.3.1 G1
if m.rttvar == 0 {
m.rttvar = clockGranularity
}
m.srtt = (1-rtoAlpha)*m.srtt + rtoAlpha*rtt
}
m.rto = math.Min(math.Max(m.srtt+4*m.rttvar, rtoMin), m.rtoMax)
Expand Down Expand Up @@ -112,7 +122,7 @@ func (m *rtoManager) setRTO(rto float64, noUpdate bool) {
m.noUpdate = noUpdate
}

// rtxTimerObserver is the inteface to a timer observer.
// rtxTimerObserver is the interface to a timer observer.
// NOTE: Observers MUST NOT call start() or stop() method on rtxTimer
// from within these callbacks.
type rtxTimerObserver interface {
Expand All @@ -128,7 +138,7 @@ const (
rtxTimerClosed
)

// rtxTimer provides the retnransmission timer conforms with RFC 4960 Sec 6.3.1.
// rtxTimer provides the retransmission timer conforms with RFC 9260 Sec 6.3.
type rtxTimer struct {
timer *time.Timer
observer rtxTimerObserver
Expand Down Expand Up @@ -198,15 +208,38 @@ func (t *rtxTimer) start(rto float64) bool {
// fast timeout for the tests. Non-test code should pass in the
// rto generated by rtoManager getRTO() method which caps the
// value at RTO.Min or at RTO.Max.

// RFC 9260 sec 6.3.2 R1: the RTO used here SHOULD include any doubling
// due to previous expirations; we therefore preserve t.nRtos and only
// update the base RTO.
t.rto = rto
t.nRtos = 0
t.state = rtxTimerStarted
t.pending++
t.timer.Reset(t.calculateNextTimeout())

return true
}

// restart restarts the running timer using the current backoff state (R3).
// It does not modify nRtos; returns false if the timer isn't running.
func (t *rtxTimer) restart() bool {
t.mutex.Lock()
defer t.mutex.Unlock()

if t.state != rtxTimerStarted {
return false
}

if t.timer.Stop() {
t.pending--
}

t.pending++
t.timer.Reset(t.calculateNextTimeout())

return true
}

// stop stops the timer.
func (t *rtxTimer) stop() {
t.mutex.Lock()
Expand All @@ -229,9 +262,26 @@ func (t *rtxTimer) close() {
if t.state == rtxTimerStarted && t.timer.Stop() {
t.pending--
}

t.state = rtxTimerClosed
}

// updateBaseRTO updates the base RTO (SRTT+4*RTTVAR) and collapses backoff.
// Call this after a NEW RTT measurement is incorporated (RFC 9260 sec 6.3.3 E3).
func (t *rtxTimer) updateBaseRTO(rto float64) {
t.mutex.Lock()
defer t.mutex.Unlock()
t.rto = rto
t.nRtos = 0
if t.state == rtxTimerStarted {
if t.timer.Stop() {
t.pending--
}
t.pending++
t.timer.Reset(t.calculateNextTimeout())
}
}

// isRunning tests if the timer is running.
// Debug purpose only.
func (t *rtxTimer) isRunning() bool {
Expand All @@ -242,11 +292,8 @@ func (t *rtxTimer) isRunning() bool {
}

func calculateNextTimeout(rto float64, nRtos uint, rtoMax float64) float64 {
// RFC 4096 sec 6.3.3. Handle T3-rtx Expiration
// E2) For the destination address for which the timer expires, set RTO
// <- RTO * 2 ("back off the timer"). The maximum value discussed
// in rule C7 above (RTO.max) may be used to provide an upper bound
// to this doubling operation.
// RFC 9260 sec 6.3.3 E2: On T3-rtx expiration, RTO <- 2*RTO; optionally
// bound by RTO.Max (per section 6.3.1 C7).
if nRtos < 31 {
m := 1 << nRtos

Expand Down
180 changes: 170 additions & 10 deletions rtx_timer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,23 @@ func TestRTOManager(t *testing.T) {
}
})

t.Run("G1 granularity bump when RTTVAR==0 (RFC 9260 sec 6.3.1)", func(t *testing.T) {
manager := newRTOManager(0)
manager.setNewRTT(100)

manager.mutex.Lock()
manager.srtt = 100
manager.rttvar = 0
manager.mutex.Unlock()

manager.setNewRTT(100)
manager.mutex.RLock()
defer manager.mutex.RUnlock()

assert.Equal(t, clockGranularity, manager.rttvar, "RTTVAR should be set to clock granularity G")
assert.GreaterOrEqual(t, manager.rto, rtoMin, "RTO must respect RTO.Min clamp")
})

t.Run("calculateNextTimeout", func(t *testing.T) {
var rto float64
rto = calculateNextTimeout(1.0, 0, defaultRTOMax)
Expand Down Expand Up @@ -126,7 +143,7 @@ func TestRtxTimer(t *testing.T) { //nolint:maintidx
// 60 : 2 (90)
// 120: 3 (210)
// 240: 4 (550) <== expected in 650 msec
assert.Equalf(t, timerID, id, "unexpted timer ID: %d", id)
assert.Equalf(t, timerID, id, "unexpected timer ID: %d", id)
},
onRtxFailure: func(_ int) {},
}, pathMaxRetrans, 0)
Expand All @@ -152,7 +169,7 @@ func TestRtxTimer(t *testing.T) { //nolint:maintidx
rt := newRTXTimer(timerID, &testTimerObserver{
onRTO: func(id int, _ uint) {
atomic.AddInt32(&nCbs, 1)
assert.Equalf(t, timerID, id, "unexpted timer ID: %d", id)
assert.Equalf(t, timerID, id, "unexpected timer ID: %d", id)
},
onRtxFailure: func(_ int) {},
}, pathMaxRetrans, 0)
Expand All @@ -172,14 +189,14 @@ func TestRtxTimer(t *testing.T) { //nolint:maintidx
assert.Equal(t, int32(1), atomic.LoadInt32(&nCbs), "must be called once")
})

t.Run("stop right afeter start", func(t *testing.T) {
t.Run("stop right after start", func(t *testing.T) {
timerID := 3
var nCbs int32

rt := newRTXTimer(timerID, &testTimerObserver{
onRTO: func(id int, _ uint) {
atomic.AddInt32(&nCbs, 1)
assert.Equalf(t, timerID, id, "unexpted timer ID: %d", id)
assert.Equalf(t, timerID, id, "unexpected timer ID: %d", id)
},
onRtxFailure: func(_ int) {},
}, pathMaxRetrans, 0)
Expand All @@ -202,7 +219,7 @@ func TestRtxTimer(t *testing.T) { //nolint:maintidx
rt := newRTXTimer(timerID, &testTimerObserver{
onRTO: func(id int, _ uint) {
atomic.AddInt32(&nCbs, 1)
assert.Equalf(t, timerID, id, "unexpted timer ID: %d", id)
assert.Equalf(t, timerID, id, "unexpected timer ID: %d", id)
},
onRtxFailure: func(_ int) {},
}, pathMaxRetrans, 0)
Expand All @@ -229,7 +246,7 @@ func TestRtxTimer(t *testing.T) { //nolint:maintidx
onRTO: func(id int, _ uint) {
atomic.AddInt32(&nCbs, 1)
t.Log("onRTO() called")
assert.Equalf(t, timerID, id, "unexpted timer ID: %d", id)
assert.Equalf(t, timerID, id, "unexpected timer ID: %d", id)
},
onRtxFailure: func(_ int) {},
}, pathMaxRetrans, 0)
Expand All @@ -254,12 +271,12 @@ func TestRtxTimer(t *testing.T) { //nolint:maintidx
var elapsed float64 // in seconds
rt := newRTXTimer(timerID, &testTimerObserver{
onRTO: func(id int, nRtos uint) {
assert.Equal(t, timerID, id, "unexpted timer ID: %d", id)
assert.Equal(t, timerID, id, "unexpected timer ID: %d", id)
t.Logf("onRTO: n=%d elapsed=%.03f\n", nRtos, time.Since(since).Seconds())
atomic.AddInt32(&nCbs, 1)
},
onRtxFailure: func(id int) {
assert.Equal(t, timerID, id, "unexpted timer ID: %d", id)
assert.Equal(t, timerID, id, "unexpected timer ID: %d", id)
elapsed = time.Since(since).Seconds()
t.Logf("onRtxFailure: elapsed=%.03f\n", elapsed)
doneCh <- true
Expand Down Expand Up @@ -297,7 +314,7 @@ func TestRtxTimer(t *testing.T) { //nolint:maintidx
var elapsed float64 // in seconds
rt := newRTXTimer(timerID, &testTimerObserver{
onRTO: func(id int, nRtos uint) {
assert.Equal(t, timerID, id, "unexpted timer ID: %d", id)
assert.Equal(t, timerID, id, "unexpected timer ID: %d", id)
elapsed = time.Since(since).Seconds()
t.Logf("onRTO: n=%d elapsed=%.03f\n", nRtos, elapsed)
atomic.AddInt32(&nCbs, 1)
Expand Down Expand Up @@ -339,7 +356,7 @@ func TestRtxTimer(t *testing.T) { //nolint:maintidx

rt := newRTXTimer(timerID, &testTimerObserver{
onRTO: func(id int, _ uint) {
assert.Equal(t, timerID, id, "unexpted timer ID: %d", id)
assert.Equal(t, timerID, id, "unexpected timer ID: %d", id)
doneCh <- true
},
onRtxFailure: func(_ int) {},
Expand Down Expand Up @@ -382,4 +399,147 @@ func TestRtxTimer(t *testing.T) { //nolint:maintidx
time.Sleep(100 * time.Millisecond)
assert.Equal(t, 0, rtoCount, "RTO should not occur")
})

t.Run("backoff persists across start; collapses on new RTT (RFC 9260 sec 6.3.2 R1, 6.3.3 E3)", func(t *testing.T) {
rt := newRTXTimer(7, &testTimerObserver{
onRTO: func(int, uint) {},
onRtxFailure: func(int) {},
}, 0, 60000) // 60000 is rtoMax
defer rt.close()

ok := rt.start(5000)
assert.True(t, ok, "start failed")

rt.timeout() // nRtos=1
rt.timeout() // nRtos=2
assert.Equal(t, uint(2), rt.nRtos, "expected backoff nRtos=2")

// stop then start again: backoff must PERSIST (R1).
rt.stop()
ok = rt.start(5000) // ms
assert.True(t, ok, "restart failed")
assert.Equal(t, uint(2), rt.nRtos, "backoff should persist across start")

// new RTT measurement collapses backoff (E3 note) via updateBaseRTO.
rt.updateBaseRTO(50) // ms
assert.Equal(t, uint(0), rt.nRtos, "backoff should collapse on fresh RTT")
})

t.Run("restart keeps current backoff (RFC 9260 sec 6.3.2 R3)", func(t *testing.T) {
rt := newRTXTimer(8, &testTimerObserver{
onRTO: func(int, uint) {},
onRtxFailure: func(int) {},
}, 0, 60000)
defer rt.close()

ok := rt.start(5000)
assert.True(t, ok, "start failed")

rt.timeout() // nRtos=1
before := rt.nRtos // snapshot
ok = rt.restart() // restart at current (backed-off) RTO

assert.True(t, ok, "restart failed")
assert.Equal(t, before, rt.nRtos, "restart must not change backoff")
})
}

func FuzzCalculateNextTimeout(f *testing.F) {
f.Add(100.0, uint(0), 60000.0)
f.Add(10.0, uint(10), 50.0)
f.Add(0.5, uint(64), 1000.0)

f.Fuzz(func(t *testing.T, rto float64, nRtos uint, rtoMax float64) {
if math.IsNaN(rto) || math.IsNaN(rtoMax) || math.IsInf(rto, 0) || math.IsInf(rtoMax, 0) || rto <= 0 || rtoMax <= 0 {
t.Skip()
}

if nRtos > 64 {
nRtos %= 65
}

got := calculateNextTimeout(rto, nRtos, rtoMax)

assert.False(t, math.IsNaN(got), "NaN result")
assert.False(t, math.IsInf(got, 0), "Inf result")
assert.GreaterOrEqual(t, got, 0.0, "negative timeout")
assert.LessOrEqual(t, got, rtoMax+1e-9, "exceeds RTO.Max")

if nRtos > 0 {
prev := calculateNextTimeout(rto, nRtos-1, rtoMax)
if prev < rtoMax {
assert.GreaterOrEqual(t, got, prev, "non-monotone growth before capping")
}
}
})
}

func FuzzRTOManager_SetNewRTT(f *testing.F) {
// Seeds (pairs of bytes -> RTTs)
f.Add([]byte{0, 100, 0, 120, 0, 90, 39, 16, 0, 10}) // ~[100,120,90,10000,10]
f.Add([]byte{0, 1, 0, 1, 0, 1, 0, 1}) // small RTTs
f.Add([]byte{0xFF, 0xFF, 0xFF, 0xFF, 0, 0, 0xEA, 0x60}) // large-ish + zeros

f.Fuzz(func(t *testing.T, data []byte) {
if len(data) < 2 {
t.Skip()
}

m := newRTOManager(0) // uses defaultRTOMax

for i := 0; i+1 < len(data); i += 2 {
u := uint16(data[i])<<8 | uint16(data[i+1])
rtt := 1 + float64(u%60000)

m.setNewRTT(rtt)

rto := m.getRTO()
assert.False(t, math.IsNaN(rto), "NaN RTO")
assert.False(t, math.IsInf(rto, 0), "Inf RTO")
assert.GreaterOrEqual(t, rto, rtoMin-1e-9, "below RTO.Min")
assert.LessOrEqual(t, rto, m.rtoMax+1e-9, "above RTO.Max")
}
})
}

func FuzzRTXTimer_StateMachine(f *testing.F) {
// seed: start, timeout, timeout, stop, updateBaseRTO, restart, close
f.Add([]byte{0, 3, 3, 1, 4, 2, 5})

f.Fuzz(func(t *testing.T, ops []byte) {
rt := newRTXTimer(99, &testTimerObserver{
onRTO: func(int, uint) {},
onRtxFailure: func(int) {},
},
0, // unlimited
60000, // cap
)
defer rt.close()

for _, op := range ops {
switch op % 6 {
case 0: // start
_ = rt.start(5_000_000) // huge base: real timer won't fire during the fuzz iteration
case 1: // stop
rt.stop()
case 2: // restart
_ = rt.restart()
case 3: // simulate expiry
rt.timeout()
case 4: // collapse backoff on new RTT
rt.updateBaseRTO(10)
case 5: // close
rt.close()
}

// invariants after each op
assert.LessOrEqual(t, int(rt.pending), 1, "pending must be 0 or 1")

if rt.state == rtxTimerClosed {
assert.False(t, rt.isRunning(), "closed timer must not run")
ok := rt.start(10)
assert.False(t, ok, "start should fail after close")
}
}
})
}
Loading