Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 10 additions & 4 deletions client.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,8 @@ func Dial(network, address string) (*Client, error) {

// DialConfig is used to pass configuration to DialURI().
type DialConfig struct {
DTLSConfig dtls.Config
TLSConfig tls.Config
DTLSConfig *dtls.Config
TLSConfig *tls.Config

Net transport.Net
}
Expand Down Expand Up @@ -76,7 +76,10 @@ func DialURI(uri *URI, cfg *DialConfig) (*Client, error) { //nolint:cyclop
}

case uri.Scheme == SchemeTypeTURNS && uri.Proto == ProtoTypeUDP:
dtlsCfg := cfg.DTLSConfig // Copy
var dtlsCfg dtls.Config
if cfg.DTLSConfig != nil {
dtlsCfg = *cfg.DTLSConfig
}
dtlsCfg.ServerName = uri.Host

udpAddr, err := net.ResolveUDPAddr("udp", addr)
Expand All @@ -94,7 +97,10 @@ func DialURI(uri *URI, cfg *DialConfig) (*Client, error) { //nolint:cyclop
}

case (uri.Scheme == SchemeTypeTURNS || uri.Scheme == SchemeTypeSTUNS) && uri.Proto == ProtoTypeTCP:
tlsCfg := cfg.TLSConfig //nolint:govet, copylocks
var tlsCfg tls.Config
if cfg.TLSConfig != nil {
tlsCfg = *cfg.TLSConfig.Clone()
}
tlsCfg.ServerName = uri.Host

tcpConn, err := nw.Dial("tcp", addr)
Expand Down
52 changes: 52 additions & 0 deletions client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ package stun
import (
"bufio"
"bytes"
"crypto/tls"
"errors"
"fmt"
"io"
Expand All @@ -19,6 +20,7 @@ import (
"testing"
"time"

"github.com/pion/dtls/v3"
"github.com/stretchr/testify/assert"
)

Expand Down Expand Up @@ -416,6 +418,56 @@ func TestDialURI(t *testing.T) {
}()
}

func TestDialURINilConfigs(t *testing.T) {
// DialConfig with nil TLSConfig/DTLSConfig should work the same as zero-value configs.
u, err := ParseURI("stun:localhost")
assert.NoError(t, err)

c, err := DialURI(u, &DialConfig{
TLSConfig: nil,
DTLSConfig: nil,
})
assert.NoError(t, err)
defer func() {
assert.NoError(t, c.Close())
}()
}

func TestDialURITLSConfigNotMutated(t *testing.T) {
// Verify that the caller's TLS config is not modified by DialURI.
tlsCfg := &tls.Config{
ServerName: "original-server-name",
MinVersion: tls.VersionTLS12,
}

u, err := ParseURI("stuns:localhost:3478")
assert.NoError(t, err)

// The dial will fail since there's no TLS server, but the config
// should still not be mutated regardless.
_, _ = DialURI(u, &DialConfig{TLSConfig: tlsCfg})

assert.Equal(t, "original-server-name", tlsCfg.ServerName,
"DialURI must not mutate the caller's TLS config")
}

func TestDialURIDTLSConfigNotMutated(t *testing.T) {
// Verify that the caller's DTLS config is not modified by DialURI.
dtlsCfg := &dtls.Config{
ServerName: "original-server-name",
}

u, err := ParseURI("turns:localhost:3478?transport=udp")
assert.NoError(t, err)

// The dial will fail since there's no DTLS server, but the config
// should still not be mutated regardless.
_, _ = DialURI(u, &DialConfig{DTLSConfig: dtlsCfg})

assert.Equal(t, "original-server-name", dtlsCfg.ServerName,
"DialURI must not mutate the caller's DTLS config")
}

func TestDialError(t *testing.T) {
_, err := Dial("bad?network", "?????")
assert.Error(t, err, "error expected")
Expand Down
Loading