Skip to content

Commit

Permalink
dkg: strengthen logic for steps update (#2797)
Browse files Browse the repository at this point in the history
Improved logic for steps update in DKG process.

category: misc
ticket: none
  • Loading branch information
pinebit authored Jan 19, 2024
1 parent 3bac84f commit 226bdfc
Show file tree
Hide file tree
Showing 4 changed files with 104 additions and 24 deletions.
5 changes: 3 additions & 2 deletions dkg/dkg.go
Original file line number Diff line number Diff line change
Expand Up @@ -478,8 +478,9 @@ func startSyncProtocol(ctx context.Context, tcpNode host.Host, key *k1.PrivateKe
break
}

// Sleep for 100ms to let clients connect with each other.
time.Sleep(time.Millisecond * 100)
// Sleep for 250ms to let clients connect with each other.
// Must be at least two times greater than the sync messages period specified in client.go NewClient().
time.Sleep(time.Millisecond * 250)
}

// Disable reconnecting clients to other peer's server once all clients are connected.
Expand Down
16 changes: 8 additions & 8 deletions dkg/sync/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ func NewClient(tcpNode host.Host, peer peer.ID, hashSig []byte, version version.
done: make(chan struct{}),
reconnect: true,
version: version,
period: 250 * time.Millisecond,
period: 100 * time.Millisecond, // Must be at least two times lower than the sync timeout (dkg.go, startSyncProtocol)
}

for _, opt := range opts {
Expand All @@ -53,7 +53,7 @@ func NewClient(tcpNode host.Host, peer peer.ID, hashSig []byte, version version.
// supports reestablishing on relay circuit recycling, and supports soft shutdown.
type Client struct {
// Mutable state
mu sync.Mutex
mu sync.RWMutex
connected bool
reconnect bool
step int
Expand Down Expand Up @@ -113,16 +113,16 @@ func (c *Client) SetStep(step int) {

// getStep returns the current step.
func (c *Client) getStep() int {
c.mu.Lock()
defer c.mu.Unlock()
c.mu.RLock()
defer c.mu.RUnlock()

return c.step
}

// IsConnected returns if client is connected to the server or not.
func (c *Client) IsConnected() bool {
c.mu.Lock()
defer c.mu.Unlock()
c.mu.RLock()
defer c.mu.RUnlock()

return c.connected
}
Expand Down Expand Up @@ -255,8 +255,8 @@ func (c *Client) DisableReconnect() {

// shouldReconnect returns true if clients should re-attempt connecting to peers.
func (c *Client) shouldReconnect() bool {
c.mu.Lock()
defer c.mu.Unlock()
c.mu.RLock()
defer c.mu.RUnlock()

return c.reconnect
}
47 changes: 33 additions & 14 deletions dkg/sync/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ type Server struct {
allCount int // Excluding self

// Mutable state
mu sync.Mutex
mu sync.RWMutex
shutdown map[peer.ID]struct{}
connected map[peer.ID]struct{}
steps map[peer.ID]int
Expand Down Expand Up @@ -95,8 +95,8 @@ func (s *Server) setErr(err error) {

// Err returns the shared error state for the server.
func (s *Server) Err() error {
s.mu.Lock()
defer s.mu.Unlock()
s.mu.RLock()
defer s.mu.RUnlock()

return s.err
}
Expand Down Expand Up @@ -144,8 +144,8 @@ func (s *Server) AwaitAllAtStep(ctx context.Context, step int) error {

// isConnected returns the shared connected state for the peer.
func (s *Server) isConnected(pID peer.ID) bool {
s.mu.Lock()
defer s.mu.Unlock()
s.mu.RLock()
defer s.mu.RUnlock()

_, ok := s.connected[pID]

Expand All @@ -170,26 +170,43 @@ func (s *Server) setShutdown(pID peer.ID) {
s.shutdown[pID] = struct{}{}
}

// setStep sets the peer's reported step.
func (s *Server) setStep(pID peer.ID, step int) {
// updateStep updates the peer's step from the reported value.
// Returns error if the reported step is not the same or monotonically increased.
func (s *Server) updateStep(pID peer.ID, step int) error {
s.mu.Lock()
defer s.mu.Unlock()

currentPeerStep, hasCurrentPeerStep := s.steps[pID]

if hasCurrentPeerStep && step < currentPeerStep {
return errors.New("peer reported step is behind the last known step", z.Int("peer_step", step), z.Int("last_step", currentPeerStep))
}

if hasCurrentPeerStep && step > currentPeerStep+1 {
return errors.New("peer reported step is ahead the last known step", z.Int("peer_step", step), z.Int("last_step", currentPeerStep))
}

if !hasCurrentPeerStep && (step < 0 || step > 1) {
return errors.New("peer reported abnormal initial step, expected 0 or 1", z.Int("peer_step", step))
}

s.steps[pID] = step

return nil
}

// isAllConnected returns if all expected peers are connected.
func (s *Server) isAllConnected() bool {
s.mu.Lock()
defer s.mu.Unlock()
s.mu.RLock()
defer s.mu.RUnlock()

return len(s.connected) == s.allCount
}

// isAllShutdown returns if all expected peers are shutdown.
func (s *Server) isAllShutdown() bool {
s.mu.Lock()
defer s.mu.Unlock()
s.mu.RLock()
defer s.mu.RUnlock()

return len(s.shutdown) == s.allCount
}
Expand All @@ -199,8 +216,8 @@ func (s *Server) isAllShutdown() bool {
// so one peer will always increment first putting it ahead of the others. At least we know all peers
// are or were at the given step.
func (s *Server) isAllAtStep(step int) (bool, error) {
s.mu.Lock()
defer s.mu.Unlock()
s.mu.RLock()
defer s.mu.RUnlock()

if len(s.steps) != s.allCount {
return false, nil
Expand Down Expand Up @@ -262,7 +279,9 @@ func (s *Server) handleStream(ctx context.Context, stream network.Stream) error
log.Info(ctx, fmt.Sprintf("Connected to peer %d of %d", count, s.allCount))
}

s.setStep(pID, int(msg.Step))
if err := s.updateStep(pID, int(msg.Step)); err != nil {
return err
}

// Write response message
if err := writeSizedProto(stream, resp); err != nil {
Expand Down
60 changes: 60 additions & 0 deletions dkg/sync/server_internal_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
// Copyright © 2022-2023 Obol Labs Inc. Licensed under the terms of a Business Source License 1.1

package sync

import (
"testing"

"github.com/libp2p/go-libp2p/core/peer"
"github.com/stretchr/testify/require"

"github.com/obolnetwork/charon/app/version"
"github.com/obolnetwork/charon/testutil"
)

func TestUpdateStep(t *testing.T) {
sv, err := version.Parse("v0.1")
require.NoError(t, err)

server := &Server{
defHash: testutil.RandomBytes32(),
tcpNode: nil,
allCount: 1,
shutdown: make(map[peer.ID]struct{}),
connected: make(map[peer.ID]struct{}),
steps: make(map[peer.ID]int),
version: sv,
}

t.Run("wrong initial step", func(t *testing.T) {
err = server.updateStep("alpha", 100)
require.ErrorContains(t, err, "peer reported abnormal initial step, expected 0 or 1")
})

t.Run("valid peer step update", func(t *testing.T) {
err = server.updateStep("bravo", 1)
require.NoError(t, err)

err = server.updateStep("bravo", 1)
require.NoError(t, err) // same step is allowed

err = server.updateStep("bravo", 2)
require.NoError(t, err) // next step is allowed
})

t.Run("peer step is behind", func(t *testing.T) {
err = server.updateStep("behind", 1)
require.NoError(t, err)

err = server.updateStep("behind", 0)
require.ErrorContains(t, err, "peer reported step is behind the last known step")
})

t.Run("peer step is ahead", func(t *testing.T) {
err = server.updateStep("ahead", 1)
require.NoError(t, err)

err = server.updateStep("ahead", 3)
require.ErrorContains(t, err, "peer reported step is ahead the last known step")
})
}

0 comments on commit 226bdfc

Please sign in to comment.