From 48d67487d931b456de0d7dbaca0c8b68f5498f4c Mon Sep 17 00:00:00 2001 From: Anonymous Date: Thu, 14 Apr 2022 16:13:21 +0100 Subject: [PATCH] Implement retransmit backoff according to 4.2.4.1 --- config.go | 4 ++++ conn.go | 3 ++- conn_test.go | 1 - e2e/e2e_lossy_test.go | 16 +++++++------ handshaker.go | 52 ++++++++++++++++++++++++++++++------------- handshaker_test.go | 12 +++++----- 6 files changed, 58 insertions(+), 30 deletions(-) diff --git a/config.go b/config.go index 3cb9ab07f..5763ce530 100644 --- a/config.go +++ b/config.go @@ -57,6 +57,10 @@ type Config struct { // defaults to time.Second FlightInterval time.Duration + // DisableRetransmitBackoff can be used to the disable the backoff feature + // when sending outbound messages as specified in RFC 4347 4.2.4.1 + DisableRetransmitBackoff bool + // PSK sets the pre-shared key used by this DTLS connection // If PSK is non-nil only PSK CipherSuites will be used PSK PSKCallback diff --git a/conn.go b/conn.go index 1c670cc2c..5d0d86fdb 100644 --- a/conn.go +++ b/conn.go @@ -202,7 +202,8 @@ func handshakeConn(ctx context.Context, conn *Conn, config *Config, isClient boo rootCAs: config.RootCAs, clientCAs: config.ClientCAs, customCipherSuites: config.CustomCipherSuites, - retransmitInterval: workerInterval, + initialRetransmitInterval: workerInterval, + disableRetransmitBackoff: config.DisableRetransmitBackoff, log: conn.log, initialEpoch: 0, keyLogWriter: config.KeyLogWriter, diff --git a/conn_test.go b/conn_test.go index 960118327..3b30b23fc 100644 --- a/conn_test.go +++ b/conn_test.go @@ -3319,7 +3319,6 @@ func TestApplicationDataQueueLimited(t *testing.T) { if qlen > maxAppDataPacketQueueSize { t.Error("too many encrypted packets enqueued", len(dconn.encryptedPackets)) } - t.Log(qlen) time.Sleep(1 * time.Second) } }() diff --git a/e2e/e2e_lossy_test.go b/e2e/e2e_lossy_test.go index f646557cd..3e03037ba 100644 --- a/e2e/e2e_lossy_test.go +++ b/e2e/e2e_lossy_test.go @@ -135,10 +135,11 @@ func TestPionE2ELossy(t *testing.T) { go func() { cfg := &dtls.Config{ - FlightInterval: flightInterval, - CipherSuites: test.CipherSuites, - InsecureSkipVerify: true, - MTU: test.MTU, + FlightInterval: flightInterval, + CipherSuites: test.CipherSuites, + InsecureSkipVerify: true, + MTU: test.MTU, + DisableRetransmitBackoff: true, } if test.DoClientAuth { @@ -151,9 +152,10 @@ func TestPionE2ELossy(t *testing.T) { go func() { cfg := &dtls.Config{ - Certificates: []tls.Certificate{serverCert}, - FlightInterval: flightInterval, - MTU: test.MTU, + Certificates: []tls.Certificate{serverCert}, + FlightInterval: flightInterval, + MTU: test.MTU, + DisableRetransmitBackoff: true, } if test.DoClientAuth { diff --git a/handshaker.go b/handshaker.go index 09c6b4e3b..4e0f1ad95 100644 --- a/handshaker.go +++ b/handshaker.go @@ -82,13 +82,14 @@ func (s handshakeState) String() string { } type handshakeFSM struct { - currentFlight flightVal - flights []*packet - retransmit bool - state *State - cache *handshakeCache - cfg *handshakeConfig - closed chan struct{} + currentFlight flightVal + flights []*packet + retransmit bool + retransmitInterval time.Duration + state *State + cache *handshakeCache + cfg *handshakeConfig + closed chan struct{} } type handshakeConfig struct { @@ -109,7 +110,8 @@ type handshakeConfig struct { sessionStore SessionStore rootCAs *x509.CertPool clientCAs *x509.CertPool - retransmitInterval time.Duration + initialRetransmitInterval time.Duration + disableRetransmitBackoff bool customCipherSuites func() []CipherSuite ellipticCurves []elliptic.Curve insecureSkipHelloVerify bool @@ -165,11 +167,12 @@ func newHandshakeFSM( initialFlight flightVal, ) *handshakeFSM { return &handshakeFSM{ - currentFlight: initialFlight, - state: s, - cache: cache, - cfg: cfg, - closed: make(chan struct{}), + currentFlight: initialFlight, + state: s, + cache: cache, + cfg: cfg, + retransmitInterval: cfg.initialRetransmitInterval, + closed: make(chan struct{}), } } @@ -274,11 +277,12 @@ func (s *handshakeFSM) wait(ctx context.Context, c flightConn) (handshakeState, return handshakeErrored, errFlight } - retransmitTimer := time.NewTimer(s.cfg.retransmitInterval) + retransmitTimer := time.NewTimer(s.retransmitInterval) for { select { case done := <-c.recvHandshake(): nextFlight, alert, err := parse(ctx, c, s.state, s.cache, s.cfg) + s.retransmitInterval = s.cfg.initialRetransmitInterval close(done) if alert != nil { if alertErr := c.notify(ctx, alert.Level, alert.Description); alertErr != nil { @@ -304,8 +308,19 @@ func (s *handshakeFSM) wait(ctx context.Context, c flightConn) (handshakeState, if !s.retransmit { return handshakeWaiting, nil } + + // RFC 4347 4.2.4.1: + // Implementations SHOULD use an initial timer value of 1 second (the minimum defined in RFC 2988 [RFC2988]) + // and double the value at each retransmission, up to no less than the RFC 2988 maximum of 60 seconds. + if !s.cfg.disableRetransmitBackoff { + s.retransmitInterval *= 2 + } + if s.retransmitInterval > time.Second*60 { + s.retransmitInterval = time.Second * 60 + } return handshakeSending, nil case <-ctx.Done(): + s.retransmitInterval = s.cfg.initialRetransmitInterval return handshakeErrored, ctx.Err() } } @@ -320,11 +335,12 @@ func (s *handshakeFSM) finish(ctx context.Context, c flightConn) (handshakeState return handshakeErrored, errFlight } - retransmitTimer := time.NewTimer(s.cfg.retransmitInterval) + retransmitTimer := time.NewTimer(s.retransmitInterval) select { case done := <-c.recvHandshake(): nextFlight, alert, err := parse(ctx, c, s.state, s.cache, s.cfg) close(done) + s.retransmitInterval = s.cfg.initialRetransmitInterval if alert != nil { if alertErr := c.notify(ctx, alert.Level, alert.Description); alertErr != nil { if err != nil { @@ -342,10 +358,16 @@ func (s *handshakeFSM) finish(ctx context.Context, c flightConn) (handshakeState return handshakeFinished, nil } <-retransmitTimer.C + // RFC 4347 4.2.4.1 + s.retransmitInterval *= 2 + if s.retransmitInterval > time.Second*60 { + s.retransmitInterval = time.Second * 60 + } // Retransmit last flight return handshakeSending, nil case <-ctx.Done(): + s.retransmitInterval = s.cfg.initialRetransmitInterval return handshakeErrored, ctx.Err() } return handshakeFinished, nil diff --git a/handshaker_test.go b/handshaker_test.go index 9bbca6f50..b2bd4443d 100644 --- a/handshaker_test.go +++ b/handshaker_test.go @@ -216,10 +216,10 @@ func TestHandshaker(t *testing.T) { } report := func(t *testing.T) { - // with one second server delay and 100 ms retransmit, there should be close to 10 `Finished` from client - // using a range of 9 - 11 for checking - if cntClientFinished < 8 || cntClientFinished > 11 { - t.Errorf("Number of client finished is wrong, expected: %d - %d times, got: %d times", 9, 11, cntClientFinished) + // with one second server delay and 100 ms retransmit (+ exponential backoff), there should be close to 4 `Finished` from client + // using a range of 3 - 5 for checking + if cntClientFinished < 3 || cntClientFinished > 5 { + t.Errorf("Number of client finished is wrong, expected: %d - %d times, got: %d times", 3, 5, cntClientFinished) } if !isClientFinished { t.Errorf("Client is not finished") @@ -281,7 +281,7 @@ func TestHandshaker(t *testing.T) { }) } }, - retransmitInterval: nonZeroRetransmitInterval, + initialRetransmitInterval: nonZeroRetransmitInterval, } fsm := newHandshakeFSM(&ca.state, ca.handshakeCache, cfg, flight1) @@ -314,7 +314,7 @@ func TestHandshaker(t *testing.T) { }) } }, - retransmitInterval: nonZeroRetransmitInterval, + initialRetransmitInterval: nonZeroRetransmitInterval, } fsm := newHandshakeFSM(&cb.state, cb.handshakeCache, cfg, flight0)