Skip to content

Commit

Permalink
[client] Add peer conn init limit (netbirdio#3001)
Browse files Browse the repository at this point in the history
Limit the peer connection initialization to 200 peers at the same time
  • Loading branch information
mlsmaycon authored Dec 9, 2024
1 parent e40a29b commit 2147bf7
Show file tree
Hide file tree
Showing 5 changed files with 131 additions and 7 deletions.
6 changes: 5 additions & 1 deletion client/internal/engine.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ import (
"github.com/netbirdio/netbird/client/internal/routemanager"
"github.com/netbirdio/netbird/client/internal/routemanager/systemops"
"github.com/netbirdio/netbird/client/internal/statemanager"
semaphoregroup "github.com/netbirdio/netbird/util/semaphore-group"

nbssh "github.com/netbirdio/netbird/client/ssh"
"github.com/netbirdio/netbird/client/system"
Expand All @@ -62,6 +63,7 @@ import (
const (
PeerConnectionTimeoutMax = 45000 // ms
PeerConnectionTimeoutMin = 30000 // ms
connInitLimit = 200
)

var ErrResetConnection = fmt.Errorf("reset connection")
Expand Down Expand Up @@ -177,6 +179,7 @@ type Engine struct {
// Network map persistence
persistNetworkMap bool
latestNetworkMap *mgmProto.NetworkMap
connSemaphore *semaphoregroup.SemaphoreGroup
}

// Peer is an instance of the Connection Peer
Expand Down Expand Up @@ -242,6 +245,7 @@ func NewEngineWithProbes(
statusRecorder: statusRecorder,
probes: probes,
checks: checks,
connSemaphore: semaphoregroup.NewSemaphoreGroup(connInitLimit),
}
if runtime.GOOS == "ios" {
if !fileExists(mobileDep.StateFilePath) {
Expand Down Expand Up @@ -1051,7 +1055,7 @@ func (e *Engine) createPeerConn(pubKey string, allowedIPs string) (*peer.Conn, e
},
}

peerConn, err := peer.NewConn(e.ctx, config, e.statusRecorder, e.signaler, e.mobileDep.IFaceDiscover, e.relayManager, e.srWatcher)
peerConn, err := peer.NewConn(e.ctx, config, e.statusRecorder, e.signaler, e.mobileDep.IFaceDiscover, e.relayManager, e.srWatcher, e.connSemaphore)
if err != nil {
return nil, err
}
Expand Down
9 changes: 7 additions & 2 deletions client/internal/peer/conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ import (
relayClient "github.com/netbirdio/netbird/relay/client"
"github.com/netbirdio/netbird/route"
nbnet "github.com/netbirdio/netbird/util/net"
semaphoregroup "github.com/netbirdio/netbird/util/semaphore-group"
)

type ConnPriority int
Expand Down Expand Up @@ -104,12 +105,13 @@ type Conn struct {
wgProxyICE wgproxy.Proxy
wgProxyRelay wgproxy.Proxy

guard *guard.Guard
guard *guard.Guard
semaphore *semaphoregroup.SemaphoreGroup
}

// NewConn creates a new not opened Conn to the remote peer.
// To establish a connection run Conn.Open
func NewConn(engineCtx context.Context, config ConnConfig, statusRecorder *Status, signaler *Signaler, iFaceDiscover stdnet.ExternalIFaceDiscover, relayManager *relayClient.Manager, srWatcher *guard.SRWatcher) (*Conn, error) {
func NewConn(engineCtx context.Context, config ConnConfig, statusRecorder *Status, signaler *Signaler, iFaceDiscover stdnet.ExternalIFaceDiscover, relayManager *relayClient.Manager, srWatcher *guard.SRWatcher, semaphore *semaphoregroup.SemaphoreGroup) (*Conn, error) {
allowedIP, _, err := net.ParseCIDR(config.WgConfig.AllowedIps)
if err != nil {
log.Errorf("failed to parse allowedIPS: %v", err)
Expand All @@ -130,6 +132,7 @@ func NewConn(engineCtx context.Context, config ConnConfig, statusRecorder *Statu
allowedIP: allowedIP,
statusRelay: NewAtomicConnStatus(),
statusICE: NewAtomicConnStatus(),
semaphore: semaphore,
}

rFns := WorkerRelayCallbacks{
Expand Down Expand Up @@ -169,6 +172,7 @@ func NewConn(engineCtx context.Context, config ConnConfig, statusRecorder *Statu
// It will try to establish a connection using ICE and in parallel with relay. The higher priority connection type will
// be used.
func (conn *Conn) Open() {
conn.semaphore.Add(conn.ctx)
conn.log.Debugf("open connection to peer")

conn.mu.Lock()
Expand All @@ -191,6 +195,7 @@ func (conn *Conn) Open() {
}

func (conn *Conn) startHandshakeAndReconnect(ctx context.Context) {
defer conn.semaphore.Done(conn.ctx)
conn.waitInitialRandomSleepTime(ctx)

err := conn.handshaker.sendOffer()
Expand Down
9 changes: 5 additions & 4 deletions client/internal/peer/conn_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ import (
"github.com/netbirdio/netbird/client/internal/peer/ice"
"github.com/netbirdio/netbird/client/internal/stdnet"
"github.com/netbirdio/netbird/util"
semaphoregroup "github.com/netbirdio/netbird/util/semaphore-group"
)

var connConf = ConnConfig{
Expand Down Expand Up @@ -46,7 +47,7 @@ func TestNewConn_interfaceFilter(t *testing.T) {

func TestConn_GetKey(t *testing.T) {
swWatcher := guard.NewSRWatcher(nil, nil, nil, connConf.ICEConfig)
conn, err := NewConn(context.Background(), connConf, nil, nil, nil, nil, swWatcher)
conn, err := NewConn(context.Background(), connConf, nil, nil, nil, nil, swWatcher, semaphoregroup.NewSemaphoreGroup(1))
if err != nil {
return
}
Expand All @@ -58,7 +59,7 @@ func TestConn_GetKey(t *testing.T) {

func TestConn_OnRemoteOffer(t *testing.T) {
swWatcher := guard.NewSRWatcher(nil, nil, nil, connConf.ICEConfig)
conn, err := NewConn(context.Background(), connConf, NewRecorder("https://mgm"), nil, nil, nil, swWatcher)
conn, err := NewConn(context.Background(), connConf, NewRecorder("https://mgm"), nil, nil, nil, swWatcher, semaphoregroup.NewSemaphoreGroup(1))
if err != nil {
return
}
Expand Down Expand Up @@ -92,7 +93,7 @@ func TestConn_OnRemoteOffer(t *testing.T) {

func TestConn_OnRemoteAnswer(t *testing.T) {
swWatcher := guard.NewSRWatcher(nil, nil, nil, connConf.ICEConfig)
conn, err := NewConn(context.Background(), connConf, NewRecorder("https://mgm"), nil, nil, nil, swWatcher)
conn, err := NewConn(context.Background(), connConf, NewRecorder("https://mgm"), nil, nil, nil, swWatcher, semaphoregroup.NewSemaphoreGroup(1))
if err != nil {
return
}
Expand Down Expand Up @@ -125,7 +126,7 @@ func TestConn_OnRemoteAnswer(t *testing.T) {
}
func TestConn_Status(t *testing.T) {
swWatcher := guard.NewSRWatcher(nil, nil, nil, connConf.ICEConfig)
conn, err := NewConn(context.Background(), connConf, NewRecorder("https://mgm"), nil, nil, nil, swWatcher)
conn, err := NewConn(context.Background(), connConf, NewRecorder("https://mgm"), nil, nil, nil, swWatcher, semaphoregroup.NewSemaphoreGroup(1))
if err != nil {
return
}
Expand Down
48 changes: 48 additions & 0 deletions util/semaphore-group/semaphore_group.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
package semaphoregroup

import (
"context"
"sync"
)

// SemaphoreGroup is a custom type that combines sync.WaitGroup and a semaphore.
type SemaphoreGroup struct {
waitGroup sync.WaitGroup
semaphore chan struct{}
}

// NewSemaphoreGroup creates a new SemaphoreGroup with the specified semaphore limit.
func NewSemaphoreGroup(limit int) *SemaphoreGroup {
return &SemaphoreGroup{
semaphore: make(chan struct{}, limit),
}
}

// Add increments the internal WaitGroup counter and acquires a semaphore slot.
func (sg *SemaphoreGroup) Add(ctx context.Context) {
sg.waitGroup.Add(1)

// Acquire semaphore slot
select {
case <-ctx.Done():
return
case sg.semaphore <- struct{}{}:
}
}

// Done decrements the internal WaitGroup counter and releases a semaphore slot.
func (sg *SemaphoreGroup) Done(ctx context.Context) {
sg.waitGroup.Done()

// Release semaphore slot
select {
case <-ctx.Done():
return
case <-sg.semaphore:
}
}

// Wait waits until the internal WaitGroup counter is zero.
func (sg *SemaphoreGroup) Wait() {
sg.waitGroup.Wait()
}
66 changes: 66 additions & 0 deletions util/semaphore-group/semaphore_group_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
package semaphoregroup

import (
"context"
"testing"
"time"
)

func TestSemaphoreGroup(t *testing.T) {
semGroup := NewSemaphoreGroup(2)

for i := 0; i < 5; i++ {
semGroup.Add(context.Background())
go func(id int) {
defer semGroup.Done(context.Background())

got := len(semGroup.semaphore)
if got == 0 {
t.Errorf("Expected semaphore length > 0 , got 0")
}

time.Sleep(time.Millisecond)
t.Logf("Goroutine %d is running\n", id)
}(i)
}

semGroup.Wait()

want := 0
got := len(semGroup.semaphore)
if got != want {
t.Errorf("Expected semaphore length %d, got %d", want, got)
}
}

func TestSemaphoreGroupContext(t *testing.T) {
semGroup := NewSemaphoreGroup(1)
semGroup.Add(context.Background())
ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second)
t.Cleanup(cancel)
rChan := make(chan struct{})

go func() {
semGroup.Add(ctx)
rChan <- struct{}{}
}()
select {
case <-rChan:
case <-time.NewTimer(2 * time.Second).C:
t.Error("Adding to semaphore group should not block when context is not done")
}

semGroup.Done(context.Background())

ctxDone, cancelDone := context.WithTimeout(context.Background(), 1*time.Second)
t.Cleanup(cancelDone)
go func() {
semGroup.Done(ctxDone)
rChan <- struct{}{}
}()
select {
case <-rChan:
case <-time.NewTimer(2 * time.Second).C:
t.Error("Releasing from semaphore group should not block when context is not done")
}
}

0 comments on commit 2147bf7

Please sign in to comment.