Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
220 changes: 194 additions & 26 deletions client.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package turn

import (
b64 "encoding/base64"
"encoding/binary"
"fmt"
"math"
"net"
Expand All @@ -21,6 +22,11 @@ const (
maxDataBufferSize = math.MaxUint16 //message size limit for Chromium
)

const (
ProtoUDP = proto.ProtoUDP
ProtoTCP = proto.ProtoTCP
)

// interval [msec]
// 0: 0 ms +500
// 1: 500 ms +1000
Expand All @@ -43,6 +49,9 @@ type ClientConfig struct {
Conn net.PacketConn // Listening socket (net.PacketConn)
LoggerFactory logging.LoggerFactory
Net *vnet.Net

TransportProtocol proto.Protocol // Protocol to peer, UDP: 17 (default) or TCP: 6
ConnectionAttemptHandler func(ConnectionID) // Incoming TCP connections
}

// Client is a STUN server client
Expand All @@ -66,6 +75,10 @@ type Client struct {
mutex sync.RWMutex // thread-safe
mutexTrMap sync.Mutex // thread-safe
log logging.LeveledLogger // read-only

transportProtocol proto.Protocol
nonce stun.Nonce
caHandler func(ConnectionID)
}

// NewClient returns a new Client instance. listeningAddress is the address and port to listen on, default "0.0.0.0:0"
Expand All @@ -87,23 +100,49 @@ func NewClient(config *ClientConfig) (*Client, error) {
log.Warn("vnet is enabled")
}

switch config.TransportProtocol {
case 0:
config.TransportProtocol = ProtoUDP
case proto.ProtoTCP, proto.ProtoUDP:
default:
return nil, fmt.Errorf("unsupported protocol: %v", config.TransportProtocol)
}

var stunServ, turnServ net.Addr
var stunServStr, turnServStr string
var err error
if len(config.STUNServerAddr) > 0 {
log.Debugf("resolving %s", config.STUNServerAddr)
stunServ, err = config.Net.ResolveUDPAddr("udp4", config.STUNServerAddr)
if err != nil {
return nil, err
switch config.TransportProtocol {
case ProtoUDP:
stunServ, err = config.Net.ResolveUDPAddr("udp4", config.STUNServerAddr)
if err != nil {
return nil, err
}
case ProtoTCP:
// TODO: switch to vnet
stunServ, err = net.ResolveTCPAddr("tcp4", config.STUNServerAddr)
if err != nil {
return nil, err
}
}
stunServStr = stunServ.String()
log.Debugf("stunServ: %s", stunServStr)
}
if len(config.TURNServerAddr) > 0 {
log.Debugf("resolving %s", config.TURNServerAddr)
turnServ, err = config.Net.ResolveUDPAddr("udp4", config.TURNServerAddr)
if err != nil {
return nil, err
switch config.TransportProtocol {
case ProtoUDP:
turnServ, err = config.Net.ResolveUDPAddr("udp4", config.TURNServerAddr)
if err != nil {
return nil, err
}
case ProtoTCP:
// TODO: switch to vnet
turnServ, err = net.ResolveTCPAddr("tcp4", config.TURNServerAddr)
if err != nil {
return nil, err
}
}
turnServStr = turnServ.String()
log.Debugf("turnServ: %s", turnServStr)
Expand All @@ -115,19 +154,21 @@ func NewClient(config *ClientConfig) (*Client, error) {
}

c := &Client{
conn: config.Conn,
stunServ: stunServ,
turnServ: turnServ,
stunServStr: stunServStr,
turnServStr: turnServStr,
username: stun.NewUsername(config.Username),
password: config.Password,
realm: stun.NewRealm(config.Realm),
software: stun.NewSoftware(config.Software),
net: config.Net,
trMap: client.NewTransactionMap(),
rto: rto,
log: log,
conn: config.Conn,
stunServ: stunServ,
turnServ: turnServ,
stunServStr: stunServStr,
turnServStr: turnServStr,
username: stun.NewUsername(config.Username),
password: config.Password,
realm: stun.NewRealm(config.Realm),
software: stun.NewSoftware(config.Software),
net: config.Net,
trMap: client.NewTransactionMap(),
rto: rto,
log: log,
transportProtocol: config.TransportProtocol,
caHandler: config.ConnectionAttemptHandler,
}

return c, nil
Expand Down Expand Up @@ -249,7 +290,7 @@ func (c *Client) Allocate() (net.PacketConn, error) {
msg, err := stun.Build(
stun.TransactionID,
stun.NewType(stun.MethodAllocate, stun.ClassRequest),
proto.RequestedTransport{Protocol: proto.ProtoUDP},
proto.RequestedTransport{Protocol: c.transportProtocol},
stun.Fingerprint,
)
if err != nil {
Expand All @@ -264,8 +305,7 @@ func (c *Client) Allocate() (net.PacketConn, error) {
res := trRes.Msg

// Anonymous allocate failed, trying to authenticate.
var nonce stun.Nonce
if err = nonce.GetFrom(res); err != nil {
if err = c.nonce.GetFrom(res); err != nil {
return nil, err
}
if err = c.realm.GetFrom(res); err != nil {
Expand All @@ -279,10 +319,10 @@ func (c *Client) Allocate() (net.PacketConn, error) {
msg, err = stun.Build(
stun.TransactionID,
stun.NewType(stun.MethodAllocate, stun.ClassRequest),
proto.RequestedTransport{Protocol: proto.ProtoUDP},
proto.RequestedTransport{Protocol: c.transportProtocol},
&c.username,
&c.realm,
&nonce,
&c.nonce,
&c.integrity,
stun.Fingerprint,
)
Expand Down Expand Up @@ -324,7 +364,7 @@ func (c *Client) Allocate() (net.PacketConn, error) {
Observer: c,
RelayedAddr: relayedAddr,
Integrity: c.integrity,
Nonce: nonce,
Nonce: c.nonce,
Lifetime: lifetime.Duration,
Log: c.log,
})
Expand Down Expand Up @@ -434,7 +474,8 @@ func (c *Client) handleSTUNMessage(data []byte, from net.Addr) error {
}

if msg.Type.Class == stun.ClassIndication {
if msg.Type.Method == stun.MethodData {
switch msg.Type.Method {
case stun.MethodData:
var peerAddr proto.PeerAddress
if err := peerAddr.GetFrom(msg); err != nil {
return err
Expand All @@ -458,6 +499,15 @@ func (c *Client) handleSTUNMessage(data []byte, from net.Addr) error {
}

relayedConn.HandleInbound(data, from)
case stun.MethodConnectionAttempt:
var cid ConnectionID
if err := cid.GetFrom(msg); err != nil {
return err
}

if c.caHandler != nil {
c.caHandler(cid)
}
}
return nil
}
Expand Down Expand Up @@ -568,3 +618,121 @@ func (c *Client) relayedUDPConn() *client.UDPConn {

return c.relayedConn
}

// Connect initiates a new TCP connection to a peer
func (c *Client) Connect(peer *net.TCPAddr) (ConnectionID, error) {
msg := stun.New()
msg.WriteHeader()
stun.TransactionID.AddTo(msg)
stun.NewType(stun.MethodConnect, stun.ClassRequest).AddTo(msg)
stun.XORMappedAddress{
IP: peer.IP,
Port: peer.Port,
}.AddToAs(msg, stun.AttrXORPeerAddress)
c.username.AddTo(msg)
c.nonce.AddTo(msg)
c.realm.AddTo(msg)
c.integrity.AddTo(msg)
stun.Fingerprint.AddTo(msg)

trRes, err := c.PerformTransaction(msg, c.turnServ, false)
if err != nil {
return 0, err
}
res := trRes.Msg

if res.Type.Class == stun.ClassErrorResponse {
var code stun.ErrorCodeAttribute
if err = code.GetFrom(res); err == nil {
return 0, fmt.Errorf("%s (error %s)", res.Type, code)
}
return 0, fmt.Errorf("%s", res.Type)
}

var cid ConnectionID
err = cid.GetFrom(res)
if err != nil {
return 0, err
}

return cid, nil
}

// ConnectionBind associates the given tcp connection with the remote connection ID.
// After a successful return the connection can be used normally.
func (c *Client) ConnectionBind(dataConn net.Conn, cid ConnectionID) error {
msg, err := stun.Build(
stun.TransactionID,
stun.NewType(stun.MethodConnectionBind, stun.ClassRequest),
cid,
&c.username,
&c.realm,
&c.nonce,
&c.integrity,
stun.Fingerprint,
)
if err != nil {
return err
}

_, err = dataConn.Write(msg.Raw)
if err != nil {
return err
}

// read exactly one STUN message,
// any data after belongs to the user
b := make([]byte, stunHeaderSize)
n, err := dataConn.Read(b)
if n != stunHeaderSize {
return errIncompleteTURNFrame
} else if err != nil {
return err
}
if !stun.IsMessage(b) {
return errInvalidTURNFrame
}

datagramSize := binary.BigEndian.Uint16(b[2:4]) + stunHeaderSize
raw := make([]byte, datagramSize)
copy(raw, b)
_, err = dataConn.Read(raw[stunHeaderSize:])
if err != nil {
return err
}
res := &stun.Message{Raw: raw}
if err := res.Decode(); err != nil {
return fmt.Errorf("failed to decode STUN message: %s", err.Error())
}

switch res.Type.Class {
case stun.ClassErrorResponse:
var code stun.ErrorCodeAttribute
if err = code.GetFrom(res); err == nil {
return fmt.Errorf("%s (error %s)", res.Type, code)
}
return fmt.Errorf("%s", res.Type)
case stun.ClassSuccessResponse:
return nil
default:
return fmt.Errorf("unexpected STUN request message: %s", res.String())
}
}

type ConnectionID uint32

func (c ConnectionID) AddTo(m *stun.Message) error {
b := make([]byte, 4)
binary.BigEndian.PutUint32(b, uint32(c))
m.Add(stun.AttrConnectionID, b)
return nil
}

func (c *ConnectionID) GetFrom(m *stun.Message) error {
b, err := m.Get(stun.AttrConnectionID)
if err != nil {
return err
}
*c = ConnectionID(binary.BigEndian.Uint32(b))
return nil
}
1 change: 1 addition & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -7,4 +7,5 @@ require (
github.com/pion/stun v0.3.5
github.com/pion/transport v0.10.0
github.com/stretchr/testify v1.6.1
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a
)
1 change: 1 addition & 0 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
golang.org/x/net v0.0.0-20200226121028-0de0cce0169b h1:0mm1VjtFUOIlE1SbDlwjYaDxZVDP2S5ou6y0gSgXHu8=
golang.org/x/net v0.0.0-20200226121028-0de0cce0169b/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a h1:1BGLXjeY4akVXGgbC9HugT3Jv3hCI0z56oJR5vAMgBU=
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
Expand Down
Loading