Skip to content

Commit d7321c1

Browse files
authored
[client] The status cmd will not be blocked by the ICE probe (#4597)
The status cmd will not be blocked by the ICE probe. Refactor the TURN and STUN probe, and cache the results. The NetBird status command will indicate a "checking…" state.
1 parent 404cab9 commit d7321c1

File tree

6 files changed

+216
-45
lines changed

6 files changed

+216
-45
lines changed

client/cmd/debug.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -168,7 +168,7 @@ func runForDuration(cmd *cobra.Command, args []string) error {
168168

169169
client := proto.NewDaemonServiceClient(conn)
170170

171-
stat, err := client.Status(cmd.Context(), &proto.StatusRequest{})
171+
stat, err := client.Status(cmd.Context(), &proto.StatusRequest{ShouldRunProbes: true})
172172
if err != nil {
173173
return fmt.Errorf("failed to get status: %v", status.Convert(err).Message())
174174
}
@@ -303,7 +303,7 @@ func setSyncResponsePersistence(cmd *cobra.Command, args []string) error {
303303

304304
func getStatusOutput(cmd *cobra.Command, anon bool) string {
305305
var statusOutputString string
306-
statusResp, err := getStatus(cmd.Context())
306+
statusResp, err := getStatus(cmd.Context(), true)
307307
if err != nil {
308308
cmd.PrintErrf("Failed to get status: %v\n", err)
309309
} else {

client/cmd/status.go

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,7 @@ func statusFunc(cmd *cobra.Command, args []string) error {
6868

6969
ctx := internal.CtxInitState(cmd.Context())
7070

71-
resp, err := getStatus(ctx)
71+
resp, err := getStatus(ctx, false)
7272
if err != nil {
7373
return err
7474
}
@@ -121,7 +121,7 @@ func statusFunc(cmd *cobra.Command, args []string) error {
121121
return nil
122122
}
123123

124-
func getStatus(ctx context.Context) (*proto.StatusResponse, error) {
124+
func getStatus(ctx context.Context, shouldRunProbes bool) (*proto.StatusResponse, error) {
125125
conn, err := DialClientGRPCServer(ctx, daemonAddr)
126126
if err != nil {
127127
return nil, fmt.Errorf("failed to connect to daemon error: %v\n"+
@@ -130,7 +130,7 @@ func getStatus(ctx context.Context) (*proto.StatusResponse, error) {
130130
}
131131
defer conn.Close()
132132

133-
resp, err := proto.NewDaemonServiceClient(conn).Status(ctx, &proto.StatusRequest{GetFullPeerStatus: true, ShouldRunProbes: true})
133+
resp, err := proto.NewDaemonServiceClient(conn).Status(ctx, &proto.StatusRequest{GetFullPeerStatus: true, ShouldRunProbes: shouldRunProbes})
134134
if err != nil {
135135
return nil, fmt.Errorf("status failed: %v", status.Convert(err).Message())
136136
}

client/internal/engine.go

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -202,6 +202,8 @@ type Engine struct {
202202
// WireGuard interface monitor
203203
wgIfaceMonitor *WGIfaceMonitor
204204
wgIfaceMonitorWg sync.WaitGroup
205+
206+
probeStunTurn *relay.StunTurnProbe
205207
}
206208

207209
// Peer is an instance of the Connection Peer
@@ -244,6 +246,7 @@ func NewEngine(
244246
statusRecorder: statusRecorder,
245247
checks: checks,
246248
connSemaphore: semaphoregroup.NewSemaphoreGroup(connInitLimit),
249+
probeStunTurn: relay.NewStunTurnProbe(relay.DefaultCacheTTL),
247250
}
248251

249252
sm := profilemanager.NewServiceManager("")
@@ -1663,7 +1666,7 @@ func (e *Engine) getRosenpassAddr() string {
16631666

16641667
// RunHealthProbes executes health checks for Signal, Management, Relay and WireGuard services
16651668
// and updates the status recorder with the latest states.
1666-
func (e *Engine) RunHealthProbes() bool {
1669+
func (e *Engine) RunHealthProbes(waitForResult bool) bool {
16671670
e.syncMsgMux.Lock()
16681671

16691672
signalHealthy := e.signal.IsHealthy()
@@ -1695,8 +1698,12 @@ func (e *Engine) RunHealthProbes() bool {
16951698
}
16961699

16971700
e.syncMsgMux.Unlock()
1698-
1699-
results := e.probeICE(stuns, turns)
1701+
var results []relay.ProbeResult
1702+
if waitForResult {
1703+
results = e.probeStunTurn.ProbeAllWaitResult(e.ctx, stuns, turns)
1704+
} else {
1705+
results = e.probeStunTurn.ProbeAll(e.ctx, stuns, turns)
1706+
}
17001707
e.statusRecorder.UpdateRelayStates(results)
17011708

17021709
relayHealthy := true
@@ -1713,13 +1720,6 @@ func (e *Engine) RunHealthProbes() bool {
17131720
return allHealthy
17141721
}
17151722

1716-
func (e *Engine) probeICE(stuns, turns []*stun.URI) []relay.ProbeResult {
1717-
return append(
1718-
relay.ProbeAll(e.ctx, relay.ProbeSTUN, stuns),
1719-
relay.ProbeAll(e.ctx, relay.ProbeTURN, turns)...,
1720-
)
1721-
}
1722-
17231723
// restartEngine restarts the engine by cancelling the client context
17241724
func (e *Engine) restartEngine() {
17251725
e.syncMsgMux.Lock()

client/internal/relay/relay.go

Lines changed: 189 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@ package relay
22

33
import (
44
"context"
5+
"crypto/sha256"
6+
"errors"
57
"fmt"
68
"net"
79
"sync"
@@ -15,15 +17,180 @@ import (
1517
nbnet "github.com/netbirdio/netbird/client/net"
1618
)
1719

20+
const (
21+
DefaultCacheTTL = 20 * time.Second
22+
probeTimeout = 6 * time.Second
23+
)
24+
25+
var (
26+
ErrCheckInProgress = errors.New("probe check is already in progress")
27+
)
28+
1829
// ProbeResult holds the info about the result of a relay probe request
1930
type ProbeResult struct {
2031
URI string
2132
Err error
2233
Addr string
2334
}
2435

36+
type StunTurnProbe struct {
37+
cacheResults []ProbeResult
38+
cacheTimestamp time.Time
39+
cacheKey string
40+
cacheTTL time.Duration
41+
probeInProgress bool
42+
probeDone chan struct{}
43+
mu sync.Mutex
44+
}
45+
46+
func NewStunTurnProbe(cacheTTL time.Duration) *StunTurnProbe {
47+
return &StunTurnProbe{
48+
cacheTTL: cacheTTL,
49+
}
50+
}
51+
52+
func (p *StunTurnProbe) ProbeAllWaitResult(ctx context.Context, stuns []*stun.URI, turns []*stun.URI) []ProbeResult {
53+
cacheKey := generateCacheKey(stuns, turns)
54+
55+
p.mu.Lock()
56+
if p.probeInProgress {
57+
doneChan := p.probeDone
58+
p.mu.Unlock()
59+
60+
select {
61+
case <-ctx.Done():
62+
log.Debugf("Context cancelled while waiting for probe results")
63+
return createErrorResults(stuns, turns)
64+
case <-doneChan:
65+
return p.getCachedResults(cacheKey, stuns, turns)
66+
}
67+
}
68+
69+
p.probeInProgress = true
70+
probeDone := make(chan struct{})
71+
p.probeDone = probeDone
72+
p.mu.Unlock()
73+
74+
p.doProbe(ctx, stuns, turns, cacheKey)
75+
close(probeDone)
76+
77+
return p.getCachedResults(cacheKey, stuns, turns)
78+
}
79+
80+
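// ProbeAll returns fresh cached results if available; otherwise it starts a background probe (unless one is already running) and returns ErrCheckInProgress placeholder results if the probe does not finish within 1.3s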
81+
func (p *StunTurnProbe) ProbeAll(ctx context.Context, stuns []*stun.URI, turns []*stun.URI) []ProbeResult {
82+
cacheKey := generateCacheKey(stuns, turns)
83+
84+
p.mu.Lock()
85+
86+
if results := p.checkCache(cacheKey); results != nil {
87+
p.mu.Unlock()
88+
return results
89+
}
90+
91+
if p.probeInProgress {
92+
p.mu.Unlock()
93+
return createErrorResults(stuns, turns)
94+
}
95+
96+
p.probeInProgress = true
97+
probeDone := make(chan struct{})
98+
p.probeDone = probeDone
99+
log.Infof("started new probe for STUN, TURN servers")
100+
go func() {
101+
p.doProbe(ctx, stuns, turns, cacheKey)
102+
close(probeDone)
103+
}()
104+
105+
p.mu.Unlock()
106+
107+
timer := time.NewTimer(1300 * time.Millisecond)
108+
defer timer.Stop()
109+
110+
select {
111+
case <-ctx.Done():
112+
log.Debugf("Context cancelled while waiting for probe results")
113+
return createErrorResults(stuns, turns)
114+
case <-probeDone:
115+
// when the probe returns quickly, return the results right away
116+
return p.getCachedResults(cacheKey, stuns, turns)
117+
case <-timer.C:
118+
// if the probe takes longer than 1.3s, return error results to avoid blocking
119+
return createErrorResults(stuns, turns)
120+
}
121+
}
122+
123+
func (p *StunTurnProbe) checkCache(cacheKey string) []ProbeResult {
124+
if p.cacheKey == cacheKey && len(p.cacheResults) > 0 {
125+
age := time.Since(p.cacheTimestamp)
126+
if age < p.cacheTTL {
127+
results := append([]ProbeResult(nil), p.cacheResults...)
128+
log.Debugf("returning cached probe results (age: %v)", age)
129+
return results
130+
}
131+
}
132+
return nil
133+
}
134+
135+
func (p *StunTurnProbe) getCachedResults(cacheKey string, stuns []*stun.URI, turns []*stun.URI) []ProbeResult {
136+
p.mu.Lock()
137+
defer p.mu.Unlock()
138+
139+
if p.cacheKey == cacheKey && len(p.cacheResults) > 0 {
140+
return append([]ProbeResult(nil), p.cacheResults...)
141+
}
142+
return createErrorResults(stuns, turns)
143+
}
144+
145+
func (p *StunTurnProbe) doProbe(ctx context.Context, stuns []*stun.URI, turns []*stun.URI, cacheKey string) {
146+
defer func() {
147+
p.mu.Lock()
148+
p.probeInProgress = false
149+
p.mu.Unlock()
150+
}()
151+
results := make([]ProbeResult, len(stuns)+len(turns))
152+
153+
var wg sync.WaitGroup
154+
for i, uri := range stuns {
155+
wg.Add(1)
156+
go func(idx int, stunURI *stun.URI) {
157+
defer wg.Done()
158+
159+
probeCtx, cancel := context.WithTimeout(ctx, probeTimeout)
160+
defer cancel()
161+
162+
results[idx].URI = stunURI.String()
163+
results[idx].Addr, results[idx].Err = p.probeSTUN(probeCtx, stunURI)
164+
}(i, uri)
165+
}
166+
167+
stunOffset := len(stuns)
168+
for i, uri := range turns {
169+
wg.Add(1)
170+
go func(idx int, turnURI *stun.URI) {
171+
defer wg.Done()
172+
173+
probeCtx, cancel := context.WithTimeout(ctx, probeTimeout)
174+
defer cancel()
175+
176+
results[idx].URI = turnURI.String()
177+
results[idx].Addr, results[idx].Err = p.probeTURN(probeCtx, turnURI)
178+
}(stunOffset+i, uri)
179+
}
180+
181+
wg.Wait()
182+
183+
p.mu.Lock()
184+
p.cacheResults = results
185+
p.cacheTimestamp = time.Now()
186+
p.cacheKey = cacheKey
187+
p.mu.Unlock()
188+
189+
log.Debug("Stored new probe results in cache")
190+
}
191+
25192
// probeSTUN tries binding to the given STUN URI and acquiring an address
26-
func ProbeSTUN(ctx context.Context, uri *stun.URI) (addr string, probeErr error) {
193+
func (p *StunTurnProbe) probeSTUN(ctx context.Context, uri *stun.URI) (addr string, probeErr error) {
27194
defer func() {
28195
if probeErr != nil {
29196
log.Debugf("stun probe error from %s: %s", uri, probeErr)
@@ -83,7 +250,7 @@ func ProbeSTUN(ctx context.Context, uri *stun.URI) (addr string, probeErr error)
83250
}
84251

85252
// probeTURN tries allocating a session from the given TURN URI
86-
func ProbeTURN(ctx context.Context, uri *stun.URI) (addr string, probeErr error) {
253+
func (p *StunTurnProbe) probeTURN(ctx context.Context, uri *stun.URI) (addr string, probeErr error) {
87254
defer func() {
88255
if probeErr != nil {
89256
log.Debugf("turn probe error from %s: %s", uri, probeErr)
@@ -160,28 +327,28 @@ func ProbeTURN(ctx context.Context, uri *stun.URI) (addr string, probeErr error)
160327
return relayConn.LocalAddr().String(), nil
161328
}
162329

163-
// ProbeAll probes all given servers asynchronously and returns the results
164-
func ProbeAll(
165-
ctx context.Context,
166-
fn func(ctx context.Context, uri *stun.URI) (addr string, probeErr error),
167-
relays []*stun.URI,
168-
) []ProbeResult {
169-
results := make([]ProbeResult, len(relays))
170-
171-
var wg sync.WaitGroup
172-
for i, uri := range relays {
173-
ctx, cancel := context.WithTimeout(ctx, 6*time.Second)
174-
defer cancel()
330+
func createErrorResults(stuns []*stun.URI, turns []*stun.URI) []ProbeResult {
331+
total := len(stuns) + len(turns)
332+
results := make([]ProbeResult, total)
175333

176-
wg.Add(1)
177-
go func(res *ProbeResult, stunURI *stun.URI) {
178-
defer wg.Done()
179-
res.URI = stunURI.String()
180-
res.Addr, res.Err = fn(ctx, stunURI)
181-
}(&results[i], uri)
334+
allURIs := append(append([]*stun.URI{}, stuns...), turns...)
335+
for i, uri := range allURIs {
336+
results[i] = ProbeResult{
337+
URI: uri.String(),
338+
Err: ErrCheckInProgress,
339+
}
182340
}
183341

184-
wg.Wait()
185-
186342
return results
187343
}
344+
345+
func generateCacheKey(stuns []*stun.URI, turns []*stun.URI) string {
346+
h := sha256.New()
347+
for _, uri := range stuns {
348+
h.Write([]byte(uri.String()))
349+
}
350+
for _, uri := range turns {
351+
h.Write([]byte(uri.String()))
352+
}
353+
return fmt.Sprintf("%x", h.Sum(nil))
354+
}

client/server/server.go

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1057,10 +1057,7 @@ func (s *Server) Status(
10571057
s.statusRecorder.UpdateRosenpass(s.config.RosenpassEnabled, s.config.RosenpassPermissive)
10581058

10591059
if msg.GetFullPeerStatus {
1060-
if msg.ShouldRunProbes {
1061-
s.runProbes()
1062-
}
1063-
1060+
s.runProbes(msg.ShouldRunProbes)
10641061
fullStatus := s.statusRecorder.GetFullStatus()
10651062
pbFullStatus := toProtoFullStatus(fullStatus)
10661063
pbFullStatus.Events = s.statusRecorder.GetEventHistory()
@@ -1070,7 +1067,7 @@ func (s *Server) Status(
10701067
return &statusResponse, nil
10711068
}
10721069

1073-
func (s *Server) runProbes() {
1070+
func (s *Server) runProbes(waitForProbeResult bool) {
10741071
if s.connectClient == nil {
10751072
return
10761073
}
@@ -1081,7 +1078,7 @@ func (s *Server) runProbes() {
10811078
}
10821079

10831080
if time.Since(s.lastProbe) > probeThreshold {
1084-
if engine.RunHealthProbes() {
1081+
if engine.RunHealthProbes(waitForProbeResult) {
10851082
s.lastProbe = time.Now()
10861083
}
10871084
}

0 commit comments

Comments
 (0)