Skip to content

Commit

Permalink
Implement retransmit backoff according to 4.2.4.1
Browse files Browse the repository at this point in the history
  • Loading branch information
Anonymous authored and Sean-Der committed Jul 2, 2024
1 parent 45e16a0 commit 48d6748
Show file tree
Hide file tree
Showing 6 changed files with 58 additions and 30 deletions.
4 changes: 4 additions & 0 deletions config.go
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,10 @@ type Config struct {
// defaults to time.Second
FlightInterval time.Duration

// DisableRetransmitBackoff can be used to the disable the backoff feature
// when sending outbound messages as specified in RFC 4347 4.2.4.1
DisableRetransmitBackoff bool

// PSK sets the pre-shared key used by this DTLS connection
// If PSK is non-nil only PSK CipherSuites will be used
PSK PSKCallback
Expand Down
3 changes: 2 additions & 1 deletion conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -202,7 +202,8 @@ func handshakeConn(ctx context.Context, conn *Conn, config *Config, isClient boo
rootCAs: config.RootCAs,
clientCAs: config.ClientCAs,
customCipherSuites: config.CustomCipherSuites,
retransmitInterval: workerInterval,
initialRetransmitInterval: workerInterval,
disableRetransmitBackoff: config.DisableRetransmitBackoff,
log: conn.log,
initialEpoch: 0,
keyLogWriter: config.KeyLogWriter,
Expand Down
1 change: 0 additions & 1 deletion conn_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3319,7 +3319,6 @@ func TestApplicationDataQueueLimited(t *testing.T) {
if qlen > maxAppDataPacketQueueSize {
t.Error("too many encrypted packets enqueued", len(dconn.encryptedPackets))
}
t.Log(qlen)
time.Sleep(1 * time.Second)
}
}()
Expand Down
16 changes: 9 additions & 7 deletions e2e/e2e_lossy_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -135,10 +135,11 @@ func TestPionE2ELossy(t *testing.T) {

go func() {
cfg := &dtls.Config{
FlightInterval: flightInterval,
CipherSuites: test.CipherSuites,
InsecureSkipVerify: true,
MTU: test.MTU,
FlightInterval: flightInterval,
CipherSuites: test.CipherSuites,
InsecureSkipVerify: true,
MTU: test.MTU,
DisableRetransmitBackoff: true,
}

if test.DoClientAuth {
Expand All @@ -151,9 +152,10 @@ func TestPionE2ELossy(t *testing.T) {

go func() {
cfg := &dtls.Config{
Certificates: []tls.Certificate{serverCert},
FlightInterval: flightInterval,
MTU: test.MTU,
Certificates: []tls.Certificate{serverCert},
FlightInterval: flightInterval,
MTU: test.MTU,
DisableRetransmitBackoff: true,
}

if test.DoClientAuth {
Expand Down
52 changes: 37 additions & 15 deletions handshaker.go
Original file line number Diff line number Diff line change
Expand Up @@ -82,13 +82,14 @@ func (s handshakeState) String() string {
}

type handshakeFSM struct {
currentFlight flightVal
flights []*packet
retransmit bool
state *State
cache *handshakeCache
cfg *handshakeConfig
closed chan struct{}
currentFlight flightVal
flights []*packet
retransmit bool
retransmitInterval time.Duration
state *State
cache *handshakeCache
cfg *handshakeConfig
closed chan struct{}
}

type handshakeConfig struct {
Expand All @@ -109,7 +110,8 @@ type handshakeConfig struct {
sessionStore SessionStore
rootCAs *x509.CertPool
clientCAs *x509.CertPool
retransmitInterval time.Duration
initialRetransmitInterval time.Duration
disableRetransmitBackoff bool
customCipherSuites func() []CipherSuite
ellipticCurves []elliptic.Curve
insecureSkipHelloVerify bool
Expand Down Expand Up @@ -165,11 +167,12 @@ func newHandshakeFSM(
initialFlight flightVal,
) *handshakeFSM {
return &handshakeFSM{
currentFlight: initialFlight,
state: s,
cache: cache,
cfg: cfg,
closed: make(chan struct{}),
currentFlight: initialFlight,
state: s,
cache: cache,
cfg: cfg,
retransmitInterval: cfg.initialRetransmitInterval,
closed: make(chan struct{}),
}
}

Expand Down Expand Up @@ -274,11 +277,12 @@ func (s *handshakeFSM) wait(ctx context.Context, c flightConn) (handshakeState,
return handshakeErrored, errFlight
}

retransmitTimer := time.NewTimer(s.cfg.retransmitInterval)
retransmitTimer := time.NewTimer(s.retransmitInterval)
for {
select {
case done := <-c.recvHandshake():
nextFlight, alert, err := parse(ctx, c, s.state, s.cache, s.cfg)
s.retransmitInterval = s.cfg.initialRetransmitInterval
close(done)
if alert != nil {
if alertErr := c.notify(ctx, alert.Level, alert.Description); alertErr != nil {
Expand All @@ -304,8 +308,19 @@ func (s *handshakeFSM) wait(ctx context.Context, c flightConn) (handshakeState,
if !s.retransmit {
return handshakeWaiting, nil
}

// RFC 4347 4.2.4.1:
// Implementations SHOULD use an initial timer value of 1 second (the minimum defined in RFC 2988 [RFC2988])
// and double the value at each retransmission, up to no less than the RFC 2988 maximum of 60 seconds.
if !s.cfg.disableRetransmitBackoff {
s.retransmitInterval *= 2
}
if s.retransmitInterval > time.Second*60 {
s.retransmitInterval = time.Second * 60
}
return handshakeSending, nil
case <-ctx.Done():
s.retransmitInterval = s.cfg.initialRetransmitInterval
return handshakeErrored, ctx.Err()
}
}
Expand All @@ -320,11 +335,12 @@ func (s *handshakeFSM) finish(ctx context.Context, c flightConn) (handshakeState
return handshakeErrored, errFlight
}

retransmitTimer := time.NewTimer(s.cfg.retransmitInterval)
retransmitTimer := time.NewTimer(s.retransmitInterval)
select {
case done := <-c.recvHandshake():
nextFlight, alert, err := parse(ctx, c, s.state, s.cache, s.cfg)
close(done)
s.retransmitInterval = s.cfg.initialRetransmitInterval
if alert != nil {
if alertErr := c.notify(ctx, alert.Level, alert.Description); alertErr != nil {
if err != nil {
Expand All @@ -342,10 +358,16 @@ func (s *handshakeFSM) finish(ctx context.Context, c flightConn) (handshakeState
return handshakeFinished, nil
}
<-retransmitTimer.C
// RFC 4347 4.2.4.1
s.retransmitInterval *= 2
if s.retransmitInterval > time.Second*60 {
s.retransmitInterval = time.Second * 60
}
// Retransmit last flight
return handshakeSending, nil

case <-ctx.Done():
s.retransmitInterval = s.cfg.initialRetransmitInterval
return handshakeErrored, ctx.Err()
}
return handshakeFinished, nil
Expand Down
12 changes: 6 additions & 6 deletions handshaker_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -216,10 +216,10 @@ func TestHandshaker(t *testing.T) {
}

report := func(t *testing.T) {
// with one second server delay and 100 ms retransmit, there should be close to 10 `Finished` from client
// using a range of 9 - 11 for checking
if cntClientFinished < 8 || cntClientFinished > 11 {
t.Errorf("Number of client finished is wrong, expected: %d - %d times, got: %d times", 9, 11, cntClientFinished)
// with one second server delay and 100 ms retransmit (+ exponential backoff), there should be close to 4 `Finished` from client
// using a range of 3 - 5 for checking
if cntClientFinished < 3 || cntClientFinished > 5 {
t.Errorf("Number of client finished is wrong, expected: %d - %d times, got: %d times", 3, 5, cntClientFinished)
}
if !isClientFinished {
t.Errorf("Client is not finished")
Expand Down Expand Up @@ -281,7 +281,7 @@ func TestHandshaker(t *testing.T) {
})
}
},
retransmitInterval: nonZeroRetransmitInterval,
initialRetransmitInterval: nonZeroRetransmitInterval,
}

fsm := newHandshakeFSM(&ca.state, ca.handshakeCache, cfg, flight1)
Expand Down Expand Up @@ -314,7 +314,7 @@ func TestHandshaker(t *testing.T) {
})
}
},
retransmitInterval: nonZeroRetransmitInterval,
initialRetransmitInterval: nonZeroRetransmitInterval,
}

fsm := newHandshakeFSM(&cb.state, cb.handshakeCache, cfg, flight0)
Expand Down

0 comments on commit 48d6748

Please sign in to comment.