@@ -2,6 +2,8 @@ package relay
22
33import (
44 "context"
5+ "crypto/sha256"
6+ "errors"
57 "fmt"
68 "net"
79 "sync"
@@ -15,15 +17,180 @@ import (
1517 nbnet "github.com/netbirdio/netbird/client/net"
1618)
1719
20+ const (
21+ DefaultCacheTTL = 20 * time .Second
22+ probeTimeout = 6 * time .Second
23+ )
24+
25+ var (
26+ ErrCheckInProgress = errors .New ("probe check is already in progress" )
27+ )
28+
1829// ProbeResult holds the info about the result of a relay probe request
1930type ProbeResult struct {
2031 URI string
2132 Err error
2233 Addr string
2334}
2435
36+ type StunTurnProbe struct {
37+ cacheResults []ProbeResult
38+ cacheTimestamp time.Time
39+ cacheKey string
40+ cacheTTL time.Duration
41+ probeInProgress bool
42+ probeDone chan struct {}
43+ mu sync.Mutex
44+ }
45+
46+ func NewStunTurnProbe (cacheTTL time.Duration ) * StunTurnProbe {
47+ return & StunTurnProbe {
48+ cacheTTL : cacheTTL ,
49+ }
50+ }
51+
52+ func (p * StunTurnProbe ) ProbeAllWaitResult (ctx context.Context , stuns []* stun.URI , turns []* stun.URI ) []ProbeResult {
53+ cacheKey := generateCacheKey (stuns , turns )
54+
55+ p .mu .Lock ()
56+ if p .probeInProgress {
57+ doneChan := p .probeDone
58+ p .mu .Unlock ()
59+
60+ select {
61+ case <- ctx .Done ():
62+ log .Debugf ("Context cancelled while waiting for probe results" )
63+ return createErrorResults (stuns , turns )
64+ case <- doneChan :
65+ return p .getCachedResults (cacheKey , stuns , turns )
66+ }
67+ }
68+
69+ p .probeInProgress = true
70+ probeDone := make (chan struct {})
71+ p .probeDone = probeDone
72+ p .mu .Unlock ()
73+
74+ p .doProbe (ctx , stuns , turns , cacheKey )
75+ close (probeDone )
76+
77+ return p .getCachedResults (cacheKey , stuns , turns )
78+ }
79+
80+ // ProbeAll probes all given servers asynchronously and returns the results
81+ func (p * StunTurnProbe ) ProbeAll (ctx context.Context , stuns []* stun.URI , turns []* stun.URI ) []ProbeResult {
82+ cacheKey := generateCacheKey (stuns , turns )
83+
84+ p .mu .Lock ()
85+
86+ if results := p .checkCache (cacheKey ); results != nil {
87+ p .mu .Unlock ()
88+ return results
89+ }
90+
91+ if p .probeInProgress {
92+ p .mu .Unlock ()
93+ return createErrorResults (stuns , turns )
94+ }
95+
96+ p .probeInProgress = true
97+ probeDone := make (chan struct {})
98+ p .probeDone = probeDone
99+ log .Infof ("started new probe for STUN, TURN servers" )
100+ go func () {
101+ p .doProbe (ctx , stuns , turns , cacheKey )
102+ close (probeDone )
103+ }()
104+
105+ p .mu .Unlock ()
106+
107+ timer := time .NewTimer (1300 * time .Millisecond )
108+ defer timer .Stop ()
109+
110+ select {
111+ case <- ctx .Done ():
112+ log .Debugf ("Context cancelled while waiting for probe results" )
113+ return createErrorResults (stuns , turns )
114+ case <- probeDone :
115+ // when the probe is return fast, return the results right away
116+ return p .getCachedResults (cacheKey , stuns , turns )
117+ case <- timer .C :
118+ // if the probe takes longer than 1.3s, return error results to avoid blocking
119+ return createErrorResults (stuns , turns )
120+ }
121+ }
122+
123+ func (p * StunTurnProbe ) checkCache (cacheKey string ) []ProbeResult {
124+ if p .cacheKey == cacheKey && len (p .cacheResults ) > 0 {
125+ age := time .Since (p .cacheTimestamp )
126+ if age < p .cacheTTL {
127+ results := append ([]ProbeResult (nil ), p .cacheResults ... )
128+ log .Debugf ("returning cached probe results (age: %v)" , age )
129+ return results
130+ }
131+ }
132+ return nil
133+ }
134+
135+ func (p * StunTurnProbe ) getCachedResults (cacheKey string , stuns []* stun.URI , turns []* stun.URI ) []ProbeResult {
136+ p .mu .Lock ()
137+ defer p .mu .Unlock ()
138+
139+ if p .cacheKey == cacheKey && len (p .cacheResults ) > 0 {
140+ return append ([]ProbeResult (nil ), p .cacheResults ... )
141+ }
142+ return createErrorResults (stuns , turns )
143+ }
144+
145+ func (p * StunTurnProbe ) doProbe (ctx context.Context , stuns []* stun.URI , turns []* stun.URI , cacheKey string ) {
146+ defer func () {
147+ p .mu .Lock ()
148+ p .probeInProgress = false
149+ p .mu .Unlock ()
150+ }()
151+ results := make ([]ProbeResult , len (stuns )+ len (turns ))
152+
153+ var wg sync.WaitGroup
154+ for i , uri := range stuns {
155+ wg .Add (1 )
156+ go func (idx int , stunURI * stun.URI ) {
157+ defer wg .Done ()
158+
159+ probeCtx , cancel := context .WithTimeout (ctx , probeTimeout )
160+ defer cancel ()
161+
162+ results [idx ].URI = stunURI .String ()
163+ results [idx ].Addr , results [idx ].Err = p .probeSTUN (probeCtx , stunURI )
164+ }(i , uri )
165+ }
166+
167+ stunOffset := len (stuns )
168+ for i , uri := range turns {
169+ wg .Add (1 )
170+ go func (idx int , turnURI * stun.URI ) {
171+ defer wg .Done ()
172+
173+ probeCtx , cancel := context .WithTimeout (ctx , probeTimeout )
174+ defer cancel ()
175+
176+ results [idx ].URI = turnURI .String ()
177+ results [idx ].Addr , results [idx ].Err = p .probeTURN (probeCtx , turnURI )
178+ }(stunOffset + i , uri )
179+ }
180+
181+ wg .Wait ()
182+
183+ p .mu .Lock ()
184+ p .cacheResults = results
185+ p .cacheTimestamp = time .Now ()
186+ p .cacheKey = cacheKey
187+ p .mu .Unlock ()
188+
189+ log .Debug ("Stored new probe results in cache" )
190+ }
191+
25192// ProbeSTUN tries binding to the given STUN uri and acquiring an address
26- func ProbeSTUN (ctx context.Context , uri * stun.URI ) (addr string , probeErr error ) {
193+ func ( p * StunTurnProbe ) probeSTUN (ctx context.Context , uri * stun.URI ) (addr string , probeErr error ) {
27194 defer func () {
28195 if probeErr != nil {
29196 log .Debugf ("stun probe error from %s: %s" , uri , probeErr )
@@ -83,7 +250,7 @@ func ProbeSTUN(ctx context.Context, uri *stun.URI) (addr string, probeErr error)
83250}
84251
85252// ProbeTURN tries allocating a session from the given TURN URI
86- func ProbeTURN (ctx context.Context , uri * stun.URI ) (addr string , probeErr error ) {
253+ func ( p * StunTurnProbe ) probeTURN (ctx context.Context , uri * stun.URI ) (addr string , probeErr error ) {
87254 defer func () {
88255 if probeErr != nil {
89256 log .Debugf ("turn probe error from %s: %s" , uri , probeErr )
@@ -160,28 +327,28 @@ func ProbeTURN(ctx context.Context, uri *stun.URI) (addr string, probeErr error)
160327 return relayConn .LocalAddr ().String (), nil
161328}
162329
163- // ProbeAll probes all given servers asynchronously and returns the results
164- func ProbeAll (
165- ctx context.Context ,
166- fn func (ctx context.Context , uri * stun.URI ) (addr string , probeErr error ),
167- relays []* stun.URI ,
168- ) []ProbeResult {
169- results := make ([]ProbeResult , len (relays ))
170-
171- var wg sync.WaitGroup
172- for i , uri := range relays {
173- ctx , cancel := context .WithTimeout (ctx , 6 * time .Second )
174- defer cancel ()
330+ func createErrorResults (stuns []* stun.URI , turns []* stun.URI ) []ProbeResult {
331+ total := len (stuns ) + len (turns )
332+ results := make ([]ProbeResult , total )
175333
176- wg . Add ( 1 )
177- go func ( res * ProbeResult , stunURI * stun. URI ) {
178- defer wg . Done ()
179- res . URI = stunURI .String ()
180- res . Addr , res . Err = fn ( ctx , stunURI )
181- }( & results [ i ], uri )
334+ allURIs := append ( append ([] * stun. URI {}, stuns ... ), turns ... )
335+ for i , uri := range allURIs {
336+ results [ i ] = ProbeResult {
337+ URI : uri .String (),
338+ Err : ErrCheckInProgress ,
339+ }
182340 }
183341
184- wg .Wait ()
185-
186342 return results
187343}
344+
345+ func generateCacheKey (stuns []* stun.URI , turns []* stun.URI ) string {
346+ h := sha256 .New ()
347+ for _ , uri := range stuns {
348+ h .Write ([]byte (uri .String ()))
349+ }
350+ for _ , uri := range turns {
351+ h .Write ([]byte (uri .String ()))
352+ }
353+ return fmt .Sprintf ("%x" , h .Sum (nil ))
354+ }
0 commit comments