From 226bdfc0c9b17613ccf56698bbab39dc9846beb8 Mon Sep 17 00:00:00 2001 From: Andrei Smirnov Date: Fri, 19 Jan 2024 11:14:02 +0300 Subject: [PATCH] dkg: strengthen logic for steps update (#2797) Improved logic for steps update in DKG process. category: misc ticket: none --- dkg/dkg.go | 5 +-- dkg/sync/client.go | 16 ++++----- dkg/sync/server.go | 47 +++++++++++++++++-------- dkg/sync/server_internal_test.go | 60 ++++++++++++++++++++++++++++++++ 4 files changed, 104 insertions(+), 24 deletions(-) create mode 100644 dkg/sync/server_internal_test.go diff --git a/dkg/dkg.go b/dkg/dkg.go index f1ea91b9a..dbbdcf494 100644 --- a/dkg/dkg.go +++ b/dkg/dkg.go @@ -478,8 +478,9 @@ func startSyncProtocol(ctx context.Context, tcpNode host.Host, key *k1.PrivateKe break } - // Sleep for 100ms to let clients connect with each other. - time.Sleep(time.Millisecond * 100) + // Sleep for 250ms to let clients connect with each other. + // Must be at least two times greater than the sync messages period specified in client.go NewClient(). + time.Sleep(time.Millisecond * 250) } // Disable reconnecting clients to other peer's server once all clients are connected. diff --git a/dkg/sync/client.go b/dkg/sync/client.go index d810a5dc9..2507f76e2 100644 --- a/dkg/sync/client.go +++ b/dkg/sync/client.go @@ -38,7 +38,7 @@ func NewClient(tcpNode host.Host, peer peer.ID, hashSig []byte, version version. done: make(chan struct{}), reconnect: true, version: version, - period: 250 * time.Millisecond, + period: 100 * time.Millisecond, // Must be at least two times lower than the sync timeout (dkg.go, startSyncProtocol) } for _, opt := range opts { @@ -53,7 +53,7 @@ func NewClient(tcpNode host.Host, peer peer.ID, hashSig []byte, version version. // supports reestablishing on relay circuit recycling, and supports soft shutdown. 
type Client struct { // Mutable state - mu sync.Mutex + mu sync.RWMutex connected bool reconnect bool step int @@ -113,16 +113,16 @@ func (c *Client) SetStep(step int) { // getStep returns the current step. func (c *Client) getStep() int { - c.mu.Lock() - defer c.mu.Unlock() + c.mu.RLock() + defer c.mu.RUnlock() return c.step } // IsConnected returns if client is connected to the server or not. func (c *Client) IsConnected() bool { - c.mu.Lock() - defer c.mu.Unlock() + c.mu.RLock() + defer c.mu.RUnlock() return c.connected } @@ -255,8 +255,8 @@ func (c *Client) DisableReconnect() { // shouldReconnect returns true if clients should re-attempt connecting to peers. func (c *Client) shouldReconnect() bool { - c.mu.Lock() - defer c.mu.Unlock() + c.mu.RLock() + defer c.mu.RUnlock() return c.reconnect } diff --git a/dkg/sync/server.go b/dkg/sync/server.go index 7496ade84..5341ca3df 100644 --- a/dkg/sync/server.go +++ b/dkg/sync/server.go @@ -57,7 +57,7 @@ type Server struct { allCount int // Excluding self // Mutable state - mu sync.Mutex + mu sync.RWMutex shutdown map[peer.ID]struct{} connected map[peer.ID]struct{} steps map[peer.ID]int @@ -95,8 +95,8 @@ func (s *Server) setErr(err error) { // Err returns the shared error state for the server. func (s *Server) Err() error { - s.mu.Lock() - defer s.mu.Unlock() + s.mu.RLock() + defer s.mu.RUnlock() return s.err } @@ -144,8 +144,8 @@ func (s *Server) AwaitAllAtStep(ctx context.Context, step int) error { // isConnected returns the shared connected state for the peer. func (s *Server) isConnected(pID peer.ID) bool { - s.mu.Lock() - defer s.mu.Unlock() + s.mu.RLock() + defer s.mu.RUnlock() _, ok := s.connected[pID] @@ -170,26 +170,43 @@ func (s *Server) setShutdown(pID peer.ID) { s.shutdown[pID] = struct{}{} } -// setStep sets the peer's reported step. -func (s *Server) setStep(pID peer.ID, step int) { +// updateStep updates the peer's step from the reported value. 
+// Returns an error unless the reported step equals the last known step or exceeds it by exactly one (a peer with no recorded step must report 0 or 1). +func (s *Server) updateStep(pID peer.ID, step int) error { s.mu.Lock() defer s.mu.Unlock() + currentPeerStep, hasCurrentPeerStep := s.steps[pID] + + if hasCurrentPeerStep && step < currentPeerStep { + return errors.New("peer reported step is behind the last known step", z.Int("peer_step", step), z.Int("last_step", currentPeerStep)) + } + + if hasCurrentPeerStep && step > currentPeerStep+1 { + return errors.New("peer reported step is ahead the last known step", z.Int("peer_step", step), z.Int("last_step", currentPeerStep)) + } + + if !hasCurrentPeerStep && (step < 0 || step > 1) { + return errors.New("peer reported abnormal initial step, expected 0 or 1", z.Int("peer_step", step)) + } + s.steps[pID] = step + + return nil } // isAllConnected returns if all expected peers are connected. func (s *Server) isAllConnected() bool { - s.mu.Lock() - defer s.mu.Unlock() + s.mu.RLock() + defer s.mu.RUnlock() return len(s.connected) == s.allCount } // isAllShutdown returns if all expected peers are shutdown. func (s *Server) isAllShutdown() bool { - s.mu.Lock() - defer s.mu.Unlock() + s.mu.RLock() + defer s.mu.RUnlock() return len(s.shutdown) == s.allCount } @@ -199,8 +216,8 @@ func (s *Server) isAllShutdown() bool { // so one peer will always increment first putting it ahead of the others. At least we know all peers // are or were at the given step. 
func (s *Server) isAllAtStep(step int) (bool, error) { - s.mu.Lock() - defer s.mu.Unlock() + s.mu.RLock() + defer s.mu.RUnlock() if len(s.steps) != s.allCount { return false, nil @@ -262,7 +279,9 @@ func (s *Server) handleStream(ctx context.Context, stream network.Stream) error log.Info(ctx, fmt.Sprintf("Connected to peer %d of %d", count, s.allCount)) } - s.setStep(pID, int(msg.Step)) + if err := s.updateStep(pID, int(msg.Step)); err != nil { + return err + } // Write response message if err := writeSizedProto(stream, resp); err != nil { diff --git a/dkg/sync/server_internal_test.go b/dkg/sync/server_internal_test.go new file mode 100644 index 000000000..abddfa3e4 --- /dev/null +++ b/dkg/sync/server_internal_test.go @@ -0,0 +1,60 @@ +// Copyright © 2022-2023 Obol Labs Inc. Licensed under the terms of a Business Source License 1.1 + +package sync + +import ( + "testing" + + "github.com/libp2p/go-libp2p/core/peer" + "github.com/stretchr/testify/require" + + "github.com/obolnetwork/charon/app/version" + "github.com/obolnetwork/charon/testutil" +) + +func TestUpdateStep(t *testing.T) { + sv, err := version.Parse("v0.1") + require.NoError(t, err) + + server := &Server{ + defHash: testutil.RandomBytes32(), + tcpNode: nil, + allCount: 1, + shutdown: make(map[peer.ID]struct{}), + connected: make(map[peer.ID]struct{}), + steps: make(map[peer.ID]int), + version: sv, + } + + t.Run("wrong initial step", func(t *testing.T) { + err = server.updateStep("alpha", 100) + require.ErrorContains(t, err, "peer reported abnormal initial step, expected 0 or 1") + }) + + t.Run("valid peer step update", func(t *testing.T) { + err = server.updateStep("bravo", 1) + require.NoError(t, err) + + err = server.updateStep("bravo", 1) + require.NoError(t, err) // same step is allowed + + err = server.updateStep("bravo", 2) + require.NoError(t, err) // next step is allowed + }) + + t.Run("peer step is behind", func(t *testing.T) { + err = server.updateStep("behind", 1) + require.NoError(t, err) + + 
err = server.updateStep("behind", 0) + require.ErrorContains(t, err, "peer reported step is behind the last known step") + }) + + t.Run("peer step is ahead", func(t *testing.T) { + err = server.updateStep("ahead", 1) + require.NoError(t, err) + + err = server.updateStep("ahead", 3) + require.ErrorContains(t, err, "peer reported step is ahead the last known step") + }) +}