Skip to content

Commit e4ae7e4

Browse files
committed
Implement client side for TCP allocations
1 parent 7abfa3b commit e4ae7e4

File tree

13 files changed

+917
-116
lines changed

13 files changed

+917
-116
lines changed

client.go

Lines changed: 129 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,7 @@ type Client struct {
6262
trMap *client.TransactionMap // thread-safe
6363
rto time.Duration // read-only
6464
relayedConn *client.UDPConn // protected by mutex ***
65+
tcpAllocation *client.TCPAllocation // protected by mutex ***
6566
allocTryLock client.TryLock // thread-safe
6667
listenTryLock client.TryLock // thread-safe
6768
net transport.Net // read-only
@@ -238,42 +239,34 @@ func (c *Client) SendBindingRequest() (net.Addr, error) {
238239
return c.SendBindingRequestTo(c.stunServ)
239240
}
240241

241-
// Allocate sends a TURN allocation request to the given transport address
242-
func (c *Client) Allocate() (net.PacketConn, error) {
243-
if err := c.allocTryLock.Lock(); err != nil {
244-
return nil, fmt.Errorf("%w: %s", errOneAllocateOnly, err.Error())
245-
}
246-
defer c.allocTryLock.Unlock()
247-
248-
relayedConn := c.relayedUDPConn()
249-
if relayedConn != nil {
250-
return nil, fmt.Errorf("%w: %s", errAlreadyAllocated, relayedConn.LocalAddr().String())
251-
}
242+
func (c *Client) sendAllocateRequest(protocol proto.Protocol) (proto.RelayedAddress, proto.Lifetime, stun.Nonce, error) {
243+
var relayed proto.RelayedAddress
244+
var lifetime proto.Lifetime
245+
var nonce stun.Nonce
252246

253247
msg, err := stun.Build(
254248
stun.TransactionID,
255249
stun.NewType(stun.MethodAllocate, stun.ClassRequest),
256-
proto.RequestedTransport{Protocol: proto.ProtoUDP},
250+
proto.RequestedTransport{Protocol: protocol},
257251
stun.Fingerprint,
258252
)
259253
if err != nil {
260-
return nil, err
254+
return relayed, lifetime, nonce, err
261255
}
262256

263257
trRes, err := c.PerformTransaction(msg, c.turnServ, false)
264258
if err != nil {
265-
return nil, err
259+
return relayed, lifetime, nonce, err
266260
}
267261

268262
res := trRes.Msg
269263

270264
// Anonymous allocate failed, trying to authenticate.
271-
var nonce stun.Nonce
272265
if err = nonce.GetFrom(res); err != nil {
273-
return nil, err
266+
return relayed, lifetime, nonce, err
274267
}
275268
if err = c.realm.GetFrom(res); err != nil {
276-
return nil, err
269+
return relayed, lifetime, nonce, err
277270
}
278271
c.realm = append([]byte(nil), c.realm...)
279272
c.integrity = stun.NewLongTermIntegrity(
@@ -283,48 +276,101 @@ func (c *Client) Allocate() (net.PacketConn, error) {
283276
msg, err = stun.Build(
284277
stun.TransactionID,
285278
stun.NewType(stun.MethodAllocate, stun.ClassRequest),
286-
proto.RequestedTransport{Protocol: proto.ProtoUDP},
279+
proto.RequestedTransport{Protocol: protocol},
287280
&c.username,
288281
&c.realm,
289282
&nonce,
290283
&c.integrity,
291284
stun.Fingerprint,
292285
)
293286
if err != nil {
294-
return nil, err
287+
return relayed, lifetime, nonce, err
295288
}
296289

297290
trRes, err = c.PerformTransaction(msg, c.turnServ, false)
298291
if err != nil {
299-
return nil, err
292+
return relayed, lifetime, nonce, err
300293
}
301294
res = trRes.Msg
302295

303296
if res.Type.Class == stun.ClassErrorResponse {
304297
var code stun.ErrorCodeAttribute
305298
if err = code.GetFrom(res); err == nil {
306-
return nil, fmt.Errorf("%s (error %s)", res.Type, code) //nolint:goerr113
299+
return relayed, lifetime, nonce, fmt.Errorf("%s (error %s)", res.Type, code) //nolint:goerr113
307300
}
308-
return nil, fmt.Errorf("%s", res.Type) //nolint:goerr113
301+
return relayed, lifetime, nonce, fmt.Errorf("%s", res.Type) //nolint:goerr113
309302
}
310303

311304
// Getting relayed addresses from response.
312-
var relayed proto.RelayedAddress
313305
if err := relayed.GetFrom(res); err != nil {
306+
return relayed, lifetime, nonce, err
307+
}
308+
309+
// Getting lifetime from response
310+
if err := lifetime.GetFrom(res); err != nil {
311+
return relayed, lifetime, nonce, err
312+
}
313+
return relayed, lifetime, nonce, nil
314+
}
315+
316+
// Allocate sends a TURN allocation request to the given transport address
317+
func (c *Client) Allocate() (net.PacketConn, error) {
318+
if err := c.allocTryLock.Lock(); err != nil {
319+
return nil, fmt.Errorf("%w: %s", errOneAllocateOnly, err.Error())
320+
}
321+
defer c.allocTryLock.Unlock()
322+
323+
relayedConn := c.relayedUDPConn()
324+
if relayedConn != nil {
325+
return nil, fmt.Errorf("%w: %s", errAlreadyAllocated, relayedConn.LocalAddr().String())
326+
}
327+
328+
relayed, lifetime, nonce, err := c.sendAllocateRequest(proto.ProtoUDP)
329+
if err != nil {
314330
return nil, err
315331
}
332+
316333
relayedAddr := &net.UDPAddr{
317334
IP: relayed.IP,
318335
Port: relayed.Port,
319336
}
320337

321-
// Getting lifetime from response
322-
var lifetime proto.Lifetime
323-
if err := lifetime.GetFrom(res); err != nil {
338+
relayedConn = client.NewUDPConn(&client.ConnConfig{
339+
Observer: c,
340+
RelayedAddr: relayedAddr,
341+
Integrity: c.integrity,
342+
Nonce: nonce,
343+
Lifetime: lifetime.Duration,
344+
Log: c.log,
345+
})
346+
c.setRelayedUDPConn(relayedConn)
347+
348+
return relayedConn, nil
349+
}
350+
351+
// Allocate TCP
352+
func (c *Client) AllocateTCP() (*client.TCPAllocation, error) {
353+
if err := c.allocTryLock.Lock(); err != nil {
354+
return nil, fmt.Errorf("%w: %s", errOneAllocateOnly, err.Error())
355+
}
356+
defer c.allocTryLock.Unlock()
357+
358+
allocation := c.getTCPAllocation()
359+
if allocation != nil {
360+
return nil, fmt.Errorf("%w: %s", errAlreadyAllocated, allocation.Addr().String())
361+
}
362+
363+
relayed, lifetime, nonce, err := c.sendAllocateRequest(proto.ProtoTCP)
364+
if err != nil {
324365
return nil, err
325366
}
326367

327-
relayedConn = client.NewUDPConn(&client.UDPConnConfig{
368+
relayedAddr := &net.TCPAddr{
369+
IP: relayed.IP,
370+
Port: relayed.Port,
371+
}
372+
373+
allocation = client.NewTCPAllocation(&client.ConnConfig{
328374
Observer: c,
329375
RelayedAddr: relayedAddr,
330376
Integrity: c.integrity,
@@ -333,15 +379,26 @@ func (c *Client) Allocate() (net.PacketConn, error) {
333379
Log: c.log,
334380
})
335381

336-
c.setRelayedUDPConn(relayedConn)
382+
c.setTCPAllocation(allocation)
337383

338-
return relayedConn, nil
384+
return allocation, nil
339385
}
340386

341387
// CreatePermission Issues a CreatePermission request for the supplied addresses
342388
// as described in https://datatracker.ietf.org/doc/html/rfc5766#section-9
343389
func (c *Client) CreatePermission(addrs ...net.Addr) error {
344-
return c.relayedUDPConn().CreatePermissions(addrs...)
390+
if conn := c.relayedUDPConn(); conn != nil {
391+
if err := conn.CreatePermissions(addrs...); err != nil {
392+
return err
393+
}
394+
}
395+
396+
if allocation := c.getTCPAllocation(); allocation != nil {
397+
if err := allocation.CreatePermissions(addrs...); err != nil {
398+
return err
399+
}
400+
}
401+
return nil
345402
}
346403

347404
// PerformTransaction performs STUN transaction
@@ -387,6 +444,7 @@ func (c *Client) PerformTransaction(msg *stun.Message, to net.Addr, ignoreResult
387444
// (Called by UDPConn)
388445
func (c *Client) OnDeallocated(relayedAddr net.Addr) {
389446
c.setRelayedUDPConn(nil)
447+
c.setTCPAllocation(nil)
390448
}
391449

392450
// HandleInbound handles data received.
@@ -445,7 +503,8 @@ func (c *Client) handleSTUNMessage(data []byte, from net.Addr) error {
445503
}
446504

447505
if msg.Type.Class == stun.ClassIndication {
448-
if msg.Type.Method == stun.MethodData {
506+
switch msg.Type.Method {
507+
case stun.MethodData:
449508
var peerAddr proto.PeerAddress
450509
if err := peerAddr.GetFrom(msg); err != nil {
451510
return err
@@ -467,8 +526,32 @@ func (c *Client) handleSTUNMessage(data []byte, from net.Addr) error {
467526
c.log.Debug("no relayed conn allocated")
468527
return nil // silently discard
469528
}
470-
471529
relayedConn.HandleInbound(data, from)
530+
case stun.MethodConnectionAttempt:
531+
var peerAddr proto.PeerAddress
532+
if err := peerAddr.GetFrom(msg); err != nil {
533+
return err
534+
}
535+
536+
addr := &net.TCPAddr{
537+
IP: peerAddr.IP,
538+
Port: peerAddr.Port,
539+
}
540+
541+
var cid proto.ConnectionID
542+
if err := cid.GetFrom(msg); err != nil {
543+
return err
544+
}
545+
546+
c.log.Debugf("connection attempt from %s", addr.String())
547+
548+
allocation := c.getTCPAllocation()
549+
if allocation == nil {
550+
c.log.Debug("no TCP allocation exists")
551+
return nil // silently discard
552+
}
553+
554+
allocation.HandleConnectionAttempt(addr, cid)
472555
}
473556
return nil
474557
}
@@ -579,3 +662,17 @@ func (c *Client) relayedUDPConn() *client.UDPConn {
579662

580663
return c.relayedConn
581664
}
665+
666+
func (c *Client) setTCPAllocation(alloc *client.TCPAllocation) {
667+
c.mutex.Lock()
668+
defer c.mutex.Unlock()
669+
670+
c.tcpAllocation = alloc
671+
}
672+
673+
func (c *Client) getTCPAllocation() *client.TCPAllocation {
674+
c.mutex.RLock()
675+
defer c.mutex.RUnlock()
676+
677+
return c.tcpAllocation
678+
}

client_test.go

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ import (
1111
"github.com/pion/logging"
1212
"github.com/pion/transport/v2/stdnet"
1313
"github.com/stretchr/testify/assert"
14+
"github.com/stretchr/testify/require"
1415
)
1516

1617
func createListeningTestClient(t *testing.T, loggerFactory logging.LoggerFactory) (*Client, net.PacketConn, bool) {
@@ -187,3 +188,53 @@ func TestClientNonceExpiration(t *testing.T) {
187188
assert.NoError(t, conn.Close())
188189
assert.NoError(t, server.Close())
189190
}
191+
192+
// Create a TCP-based allocation and verify allocation can be created
193+
func TestTCPClient(t *testing.T) {
194+
// Setup server
195+
tcpListener, err := net.Listen("tcp4", "0.0.0.0:13478")
196+
require.NoError(t, err)
197+
198+
server, err := NewServer(ServerConfig{
199+
AuthHandler: func(username, realm string, srcAddr net.Addr) (key []byte, ok bool) {
200+
return GenerateAuthKey(username, realm, "pass"), true
201+
},
202+
ListenerConfigs: []ListenerConfig{
203+
{
204+
Listener: tcpListener,
205+
RelayAddressGenerator: &RelayAddressGeneratorStatic{
206+
RelayAddress: net.ParseIP("127.0.0.1"),
207+
Address: "0.0.0.0",
208+
},
209+
},
210+
},
211+
Realm: "pion.ly",
212+
})
213+
require.NoError(t, err)
214+
215+
// Setup clients
216+
conn, err := net.Dial("tcp", "127.0.0.1:13478")
217+
require.NoError(t, err)
218+
219+
client, err := NewClient(&ClientConfig{
220+
Conn: NewSTUNConn(conn),
221+
STUNServerAddr: "127.0.0.1:13478",
222+
TURNServerAddr: "127.0.0.1:13478",
223+
Username: "foo",
224+
Password: "pass",
225+
})
226+
require.NoError(t, err)
227+
require.NoError(t, client.Listen())
228+
229+
allocation, err := client.AllocateTCP()
230+
require.NoError(t, err)
231+
232+
// TODO: Implement server side handling of Connect and ConnectionBind
233+
// _, err = allocation.Dial(&net.TCPAddr{IP: net.ParseIP("127.0.0.1"), Port: 8080})
234+
// assert.NoError(t, err)
235+
236+
// Shutdown
237+
require.NoError(t, allocation.Close())
238+
require.NoError(t, conn.Close())
239+
require.NoError(t, server.Close())
240+
}

0 commit comments

Comments
 (0)