Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
159 changes: 156 additions & 3 deletions internal/allocation/allocation.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,10 @@
package allocation

import (
"bytes"
"encoding/gob"
"net"
"strings"
"sync"
"sync/atomic"
"time"
Expand Down Expand Up @@ -40,13 +43,23 @@ type Allocation struct {
log logging.LeveledLogger

tcpConnections map[proto.ConnectionID]net.Conn // Guarded by AllocationManager lock
expiresAt time.Time

// Some clients (Firefox or others using resiprocate's nICE lib) may retry allocation
// with same 5 tuple when received 413, for compatible with these clients,
// cache for response lost and client retry to implement 'stateless stack approach'
// See: https://datatracker.ietf.org/doc/html/rfc5766#section-6.2
responseCache atomic.Value // *allocationResponse
}
type serializedAllocation struct {
RelayAdd string
Protocol Protocol
FiveTuple []byte
Permissions [][]byte
ChannelBindings [][]byte
Username, Realm string
ExpiresAt time.Time
}

// NewAllocation creates a new instance of NewAllocation.
func NewAllocation(
Expand All @@ -66,6 +79,133 @@ func NewAllocation(
}
}

func (a *Allocation) serialize() *serializedAllocation {
var serialized serializedAllocation
var err error
if serialized.FiveTuple, err = a.fiveTuple.MarshalBinary(); err != nil {
a.log.Errorf("failed to marshal FiveTuple: %v", err)
}
serialized.Permissions = make([][]byte, 0, len(a.permissions))
var data []byte
for _, p := range a.permissions {
data, err = p.MarshalBinary()
if err != nil {
a.log.Errorf("failed to marshal Permission: %v", err)

return nil
}
serialized.Permissions = append(serialized.Permissions, data)
}
serialized.ChannelBindings = make([][]byte, 0, len(a.channelBindings))
for _, cb := range a.channelBindings {
data, err = cb.MarshalBinary()
if err != nil {
a.log.Errorf("failed to marshal ChannelBind: %v", err)

return nil
}
serialized.ChannelBindings = append(serialized.ChannelBindings, data)
}
serialized.Realm = a.realm
serialized.Username = a.username
serialized.Protocol = a.Protocol
serialized.RelayAdd = a.RelayAddr.String()
serialized.ExpiresAt = a.expiresAt
a.tcpConnections = make(map[proto.ConnectionID]net.Conn)

return &serialized
}

func (a *Allocation) deserialize(serialized *serializedAllocation) {
if err := a.fiveTuple.UnmarshalBinary(serialized.FiveTuple); err != nil {
a.log.Errorf("failed to unmarshal FiveTuple: %v", err)

return
}

if err := a.deserializePermissions(serialized.Permissions); err != nil {
a.log.Errorf("failed to unmarshal permissions: %v", err)
}

if err := a.deserializeChannelBindings(serialized.ChannelBindings); err != nil {
a.log.Errorf("failed to unmarshal channel bindings: %v", err)
}

a.realm = serialized.Realm
a.username = serialized.Username
a.Protocol = serialized.Protocol
a.setRelayAddr(serialized.Protocol, serialized.RelayAdd)
a.expiresAt = serialized.ExpiresAt
remaningTime := time.Until(a.expiresAt)
if remaningTime > 0 {
if a.lifetimeTimer != nil {
a.lifetimeTimer.Reset(remaningTime)
}
}
}

func (a *Allocation) setRelayAddr(protocol Protocol, relayAdd string) {
network := strings.ToLower(protocol.String())
switch protocol {
case UDP:
a.RelayAddr, _ = net.ResolveUDPAddr(network, relayAdd)
case TCP:
a.RelayAddr, _ = net.ResolveTCPAddr(network, relayAdd)
default:
a.log.Errorf("%s %v", errUnsupportedProtocol, protocol)
}
}

func (a *Allocation) deserializePermissions(permissions [][]byte) error {
a.permissions = make(map[string]*Permission, 64)
for _, p := range permissions {
perm := &Permission{}
if err := perm.UnmarshalBinary(p); err != nil {
return err
}
a.permissions[ipnet.FingerprintAddr(perm.Addr)] = perm
}

return nil
}

func (a *Allocation) deserializeChannelBindings(channelBindings [][]byte) error {
a.channelBindings = make([]*ChannelBind, 0, 64)
for _, cb := range channelBindings {
channelBind := &ChannelBind{}
if err := channelBind.UnmarshalBinary(cb); err != nil {
return err
}
a.channelBindings = append(a.channelBindings, channelBind)
}

return nil
}

func (a *Allocation) MarshalBinary() ([]byte, error) {
var buf bytes.Buffer

enc := gob.NewEncoder(&buf)

serialized := a.serialize()
if err := enc.Encode(*serialized); err != nil {
return nil, err
}

return buf.Bytes(), nil
}

func (a *Allocation) UnmarshalBinary(data []byte) error {
var serialized serializedAllocation
dec := gob.NewDecoder(bytes.NewBuffer(data))
if err := dec.Decode(&serialized); err != nil {
return err
}
a.deserialize(&serialized)

return nil
}

// GetPermission gets the Permission from the allocation.
func (a *Allocation) GetPermission(addr net.Addr) *Permission {
a.permissionsLock.RLock()
Expand Down Expand Up @@ -229,11 +369,24 @@ func (a *Allocation) ListChannelBindings() []*ChannelBind {

// Refresh updates the allocations lifetime.
func (a *Allocation) Refresh(lifetime time.Duration) {
a.expiresAt = time.Now().Add(lifetime)
if !a.lifetimeTimer.Reset(lifetime) {
a.log.Errorf("Failed to reset allocation timer for %v", a.fiveTuple)
}
}

func (a *Allocation) Stop() {
if a.lifetimeTimer == nil {
a.log.Errorf("Allocation timer was nil for %v", a.fiveTuple)

return
}
if !a.lifetimeTimer.Stop() {
a.log.Warnf("Allocation timer for %v had already fired when Stop() was called", a.fiveTuple)
}
a.expiresAt = time.Now()
}

// SetResponseCache cache allocation response for retransmit allocation request.
func (a *Allocation) SetResponseCache(transactionID [stun.TransactionIDSize]byte, attrs []stun.Setter) {
a.responseCache.Store(&allocationResponse{
Expand All @@ -260,16 +413,16 @@ func (a *Allocation) Close() error {
}
close(a.closed)

a.lifetimeTimer.Stop()
a.Stop()

for _, p := range a.ListPermissions() {
a.RemovePermission(p.Addr)
p.lifetimeTimer.Stop()
p.stop()
}

for _, c := range a.ListChannelBindings() {
a.RemoveChannelBind(c.Number)
c.lifetimeTimer.Stop()
c.stop()
}

return a.RelaySocket.Close()
Expand Down
70 changes: 65 additions & 5 deletions internal/allocation/allocation_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,18 +7,78 @@
package allocation

import (
"context"
"fmt"
"net"
"sync"
"testing"
"time"

"github.com/pion/logging"
"github.com/pion/stun/v3"
"github.com/pion/turn/v4/internal/ipnet"
"github.com/pion/turn/v4/internal/proto"
"github.com/stretchr/testify/assert"
)

func TestAllocation_MarshalUnmarshalBinary(t *testing.T) {
t.Run("Normal", func(t *testing.T) {
fiveTuple := &FiveTuple{
SrcAddr: &net.UDPAddr{IP: net.ParseIP("127.0.0.1"), Port: 1234},
DstAddr: &net.UDPAddr{IP: net.ParseIP("127.0.0.1"), Port: 5678},
}

turnSocket, err := (&net.ListenConfig{}).ListenPacket(context.Background(), "udp4", "127.0.0.1:0")
assert.NoError(t, err)
defer func() {
assert.NoError(t, turnSocket.Close())
}()

loggerFactory := logging.NewDefaultLoggerFactory()
log := loggerFactory.NewLogger("test")

alloc := NewAllocation(turnSocket, fiveTuple, EventHandler{}, log)
alloc.Protocol = UDP
alloc.username = "user"
alloc.realm = "realm"
alloc.expiresAt = time.Now().Add(time.Minute).Round(time.Second)

relayAddr, err := net.ResolveUDPAddr("udp", "127.0.0.1:1000")
assert.NoError(t, err)
alloc.RelayAddr = relayAddr

permAddr, err := net.ResolveUDPAddr("udp", "127.0.0.1:2000")
assert.NoError(t, err)
alloc.AddPermission(NewPermission(permAddr, log, DefaultPermissionTimeout))

chanAddr, err := net.ResolveUDPAddr("udp", "127.0.0.1:3000")
assert.NoError(t, err)
err = alloc.AddChannelBind(NewChannelBind(proto.MinChannelNumber, chanAddr, log),
proto.DefaultLifetime,
DefaultPermissionTimeout,
)
assert.NoError(t, err)

data, err := alloc.MarshalBinary()
assert.NoError(t, err)

newAlloc := NewAllocation(nil, &FiveTuple{}, EventHandler{}, log)
// Initialize lifetimeTimer before unmarshaling, as Refresh expects it.
// Use a dummy long-running timer so it doesn't fire during the test.
newAlloc.lifetimeTimer = time.AfterFunc(time.Hour, func() {})
err = newAlloc.UnmarshalBinary(data)
assert.NoError(t, err)
assert.Equal(t, alloc.fiveTuple.String(), newAlloc.fiveTuple.String())
assert.Equal(t, alloc.Protocol, newAlloc.Protocol)
assert.Equal(t, alloc.username, newAlloc.username)
assert.Equal(t, alloc.realm, newAlloc.realm)
assert.True(t, alloc.expiresAt.Equal(newAlloc.expiresAt))
assert.Equal(t, alloc.RelayAddr.String(), newAlloc.RelayAddr.String())
assert.Equal(t, len(alloc.permissions), len(newAlloc.permissions))
assert.Equal(t, len(alloc.channelBindings), len(newAlloc.channelBindings))
})
}

func TestGetPermission(t *testing.T) {
alloc := NewAllocation(nil, nil, EventHandler{}, nil)

Expand Down Expand Up @@ -192,7 +252,7 @@ func TestAllocationRefresh(t *testing.T) {
func TestAllocationClose(t *testing.T) {
network := "udp"

l, err := net.ListenPacket(network, "0.0.0.0:0") // nolint: noctx
l, err := (&net.ListenConfig{}).ListenPacket(context.Background(), network, "0.0.0.0:0")
assert.NoError(t, err)

alloc := NewAllocation(nil, nil, EventHandler{}, nil)
Expand Down Expand Up @@ -220,11 +280,11 @@ func TestPacketHandler(t *testing.T) {
manager, _ := newTestManager()

// TURN server initialization
turnSocket, err := net.ListenPacket(network, "127.0.0.1:0") // nolint: noctx
turnSocket, err := (&net.ListenConfig{}).ListenPacket(context.Background(), network, "127.0.0.1:0")
assert.NoError(t, err)

// Client listener initialization
clientListener, err := net.ListenPacket(network, "127.0.0.1:0") // nolint: noctx
clientListener, err := (&net.ListenConfig{}).ListenPacket(context.Background(), network, "127.0.0.1:0")
assert.NoError(t, err)

dataCh := make(chan []byte)
Expand All @@ -248,10 +308,10 @@ func TestPacketHandler(t *testing.T) {

assert.NoError(t, err, "should succeed")

peerListener1, err := net.ListenPacket(network, "127.0.0.1:0") // nolint: noctx
peerListener1, err := (&net.ListenConfig{}).ListenPacket(context.Background(), network, "127.0.0.1:0")
assert.NoError(t, err)

peerListener2, err := net.ListenPacket(network, "127.0.0.1:0") // nolint: noctx
peerListener2, err := (&net.ListenConfig{}).ListenPacket(context.Background(), network, "127.0.0.1:0")
assert.NoError(t, err)

// Add permission with peer1 address
Expand Down
Loading
Loading