Skip to content

Commit

Permalink
Migrate to use ConnCtx
Browse files Browse the repository at this point in the history
Use Conn controlled by context.Context.
If you want to use net.Conn as a underlying connection,
wrapping it by github.com/pion/transport/ctxconn.ConnCtx would be the
easiest way.
  • Loading branch information
at-wat committed Nov 29, 2020
1 parent 1354c92 commit 8cee484
Show file tree
Hide file tree
Showing 12 changed files with 142 additions and 97 deletions.
12 changes: 12 additions & 0 deletions conn.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
package srtp

import (
"context"
)

// ConnCtx is a Conn controlled by context.Context instead of SetDeadline.
type ConnCtx interface {
ReadContext(context.Context, []byte) (int, error)
WriteContext(context.Context, []byte) (int, error)
Close() error
}
2 changes: 1 addition & 1 deletion go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,6 @@ require (
github.com/pion/logging v0.2.2
github.com/pion/rtcp v1.2.4
github.com/pion/rtp v1.6.1
github.com/pion/transport v0.10.1
github.com/pion/transport v0.11.0
github.com/stretchr/testify v1.6.1
)
12 changes: 7 additions & 5 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,8 @@ github.com/pion/rtcp v1.2.4 h1:NT3H5LkUGgaEapvp0HGik+a+CpflRF7KTD7H+o7OWIM=
github.com/pion/rtcp v1.2.4/go.mod h1:52rMNPWFsjr39z9B9MhnkqhPLoeHTv1aN63o/42bWE0=
github.com/pion/rtp v1.6.1 h1:2Y2elcVBrahYnHKN2X7rMHX/r1R4TEBMP1LaVu/wNhk=
github.com/pion/rtp v1.6.1/go.mod h1:bDb5n+BFZxXx0Ea7E5qe+klMuqiBrP+w8XSjiWtCUko=
github.com/pion/transport v0.10.1 h1:2W+yJT+0mOQ160ThZYUx5Zp2skzshiNgxrNE9GUfhJM=
github.com/pion/transport v0.10.1/go.mod h1:PBis1stIILMiis0PewDw91WJeLJkyIMcEk+DwKOzf4A=
github.com/pion/transport v0.11.0 h1:Z1RhzqrWPPYj5Xed8P7pirTKTvXFoxDI3uJuuKu6akM=
github.com/pion/transport v0.11.0/go.mod h1:ORH8Ouyl1enoJyHwU+MwMeQocWbeorEk5068FOsHjog=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
Expand All @@ -20,12 +20,14 @@ github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto=
golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
golang.org/x/net v0.0.0-20200625001655-4c5254603344 h1:vGXIOMxbNfDTk/aXCmfdLgkrSV+Z2tcbze+pEc3v5W4=
golang.org/x/net v0.0.0-20200625001655-4c5254603344/go.mod h1:/O7V0waA8r7cgGh81Ro3o1hOxt32SMVPicZroKQ2sZA=
golang.org/x/net v0.0.0-20201031054903-ff519b6c9102 h1:42cLlJJdEh+ySyeUUbEQ5bsTiq8voBeTuweGVkY6Puw=
golang.org/x/net v0.0.0-20201031054903-ff519b6c9102/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU=
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20200323222414-85ca7c5b95cd/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20200930185726-fdedc70b468f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c h1:dUUwHk2QECo/6vqA44rthZ8ie2QXMNeKRTHCNY2nXvo=
Expand Down
14 changes: 7 additions & 7 deletions session.go
Original file line number Diff line number Diff line change
@@ -1,17 +1,17 @@
package srtp

import (
"context"
"io"
"net"
"sync"

"github.com/pion/logging"
)

type streamSession interface {
Close() error
write([]byte) (int, error)
decrypt([]byte) error
write(context.Context, []byte) (int, error)
decrypt(context.Context, []byte) error
}

type session struct {
Expand All @@ -30,7 +30,7 @@ type session struct {

log logging.LeveledLogger

nextConn net.Conn
nextConn ConnCtx
}

// Config is used to configure a session.
Expand Down Expand Up @@ -102,7 +102,7 @@ func (s *session) close() error {
return nil
}

func (s *session) start(localMasterKey, localMasterSalt, remoteMasterKey, remoteMasterSalt []byte, profile ProtectionProfile, child streamSession) error {
func (s *session) start(ctx context.Context, localMasterKey, localMasterSalt, remoteMasterKey, remoteMasterSalt []byte, profile ProtectionProfile, child streamSession) error {
var err error
s.localContext, err = CreateContext(localMasterKey, localMasterSalt, profile, s.localOptions...)
if err != nil {
Expand All @@ -127,15 +127,15 @@ func (s *session) start(localMasterKey, localMasterSalt, remoteMasterKey, remote
b := make([]byte, 8192)
for {
var i int
i, err = s.nextConn.Read(b)
i, err = s.nextConn.ReadContext(ctx, b)
if err != nil {
if err != io.EOF {
s.log.Error(err.Error())
}
return
}

if err = child.decrypt(b[:i]); err != nil {
if err = child.decrypt(ctx, b[:i]); err != nil {
s.log.Info(err.Error())
}
}
Expand Down
13 changes: 7 additions & 6 deletions session_srtcp.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
package srtp

import (
"net"
"context"

"github.com/pion/logging"
"github.com/pion/rtcp"
Expand All @@ -19,7 +19,7 @@ type SessionSRTCP struct {
}

// NewSessionSRTCP creates a SRTCP session using conn as the underlying transport.
func NewSessionSRTCP(conn net.Conn, config *Config) (*SessionSRTCP, error) { //nolint:dupl
func NewSessionSRTCP(ctx context.Context, conn ConnCtx, config *Config) (*SessionSRTCP, error) { //nolint:dupl
if config == nil {
return nil, errNoConfig
} else if conn == nil {
Expand Down Expand Up @@ -58,6 +58,7 @@ func NewSessionSRTCP(conn net.Conn, config *Config) (*SessionSRTCP, error) { //n
s.writeStream = &WriteStreamSRTCP{s}

err := s.session.start(
ctx,
config.Keys.LocalMasterKey, config.Keys.LocalMasterSalt,
config.Keys.RemoteMasterKey, config.Keys.RemoteMasterSalt,
config.Profile,
Expand Down Expand Up @@ -107,7 +108,7 @@ func (s *SessionSRTCP) Close() error {

// Private

func (s *SessionSRTCP) write(buf []byte) (int, error) {
func (s *SessionSRTCP) write(ctx context.Context, buf []byte) (int, error) {
if _, ok := <-s.session.started; ok {
return 0, errStartedChannelUsedIncorrectly
}
Expand All @@ -119,7 +120,7 @@ func (s *SessionSRTCP) write(buf []byte) (int, error) {
if err != nil {
return 0, err
}
return s.session.nextConn.Write(encrypted)
return s.session.nextConn.WriteContext(ctx, encrypted)
}

// create a list of Destination SSRCs
Expand All @@ -140,7 +141,7 @@ func destinationSSRC(pkts []rtcp.Packet) []uint32 {
return out
}

func (s *SessionSRTCP) decrypt(buf []byte) error {
func (s *SessionSRTCP) decrypt(ctx context.Context, buf []byte) error {
decrypted, err := s.remoteContext.DecryptRTCP(buf, buf, nil)
if err != nil {
return err
Expand All @@ -164,7 +165,7 @@ func (s *SessionSRTCP) decrypt(buf []byte) error {
return errFailedTypeAssertion
}

_, err = readStream.write(decrypted)
_, err = readStream.write(ctx, decrypted)
if err != nil {
return err
}
Expand Down
43 changes: 27 additions & 16 deletions session_srtcp_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,29 +2,34 @@ package srtp

import (
"bytes"
"context"
"io"
"net"
"reflect"
"sync"
"testing"
"time"

"github.com/pion/rtcp"
"github.com/pion/transport/connctx"
"github.com/pion/transport/test"
)

const rtcpHeaderSize = 4

func TestSessionSRTCPBadInit(t *testing.T) {
if _, err := NewSessionSRTCP(nil, nil); err == nil {
ctx := context.Background()

if _, err := NewSessionSRTCP(ctx, nil, nil); err == nil {
t.Fatal("NewSessionSRTCP should error if no config was provided")
} else if _, err := NewSessionSRTCP(nil, &Config{}); err == nil {
} else if _, err := NewSessionSRTCP(ctx, nil, &Config{}); err == nil {
t.Fatal("NewSessionSRTCP should error if no net was provided")
}
}

func buildSessionSRTCPPair(t *testing.T) (*SessionSRTCP, *SessionSRTCP) { //nolint:dupl
aPipe, bPipe := net.Pipe()
ctx := context.Background()

aPipe, bPipe := connctx.Pipe()
config := &Config{
Profile: ProtectionProfileAes128CmHmacSha1_80,
Keys: SessionKeys{
Expand All @@ -35,14 +40,14 @@ func buildSessionSRTCPPair(t *testing.T) (*SessionSRTCP, *SessionSRTCP) { //noli
},
}

aSession, err := NewSessionSRTCP(aPipe, config)
aSession, err := NewSessionSRTCP(ctx, aPipe, config)
if err != nil {
t.Fatal(err)
} else if aSession == nil {
t.Fatal("NewSessionSRTCP did not error, but returned nil session")
}

bSession, err := NewSessionSRTCP(bPipe, config)
bSession, err := NewSessionSRTCP(ctx, bPipe, config)
if err != nil {
t.Fatal(err)
} else if bSession == nil {
Expand All @@ -53,6 +58,8 @@ func buildSessionSRTCPPair(t *testing.T) (*SessionSRTCP, *SessionSRTCP) { //noli
}

func TestSessionSRTCP(t *testing.T) {
ctx := context.Background()

lim := test.TimeOut(time.Second * 10)
defer lim.Stop()

Expand All @@ -71,7 +78,7 @@ func TestSessionSRTCP(t *testing.T) {
t.Fatal(err)
}

if _, err = aWriteStream.Write(testPayload); err != nil {
if _, err = aWriteStream.Write(ctx, testPayload); err != nil {
t.Fatal(err)
}

Expand All @@ -80,7 +87,7 @@ func TestSessionSRTCP(t *testing.T) {
t.Fatal(err)
}

if _, err = bReadStream.Read(readBuffer); err != nil {
if _, err = bReadStream.Read(ctx, readBuffer); err != nil {
t.Fatal(err)
}

Expand All @@ -98,6 +105,8 @@ func TestSessionSRTCP(t *testing.T) {
}

func TestSessionSRTCPOpenReadStream(t *testing.T) {
ctx := context.Background()

lim := test.TimeOut(time.Second * 10)
defer lim.Stop()

Expand All @@ -121,11 +130,11 @@ func TestSessionSRTCPOpenReadStream(t *testing.T) {
t.Fatal(err)
}

if _, err = aWriteStream.Write(testPayload); err != nil {
if _, err = aWriteStream.Write(ctx, testPayload); err != nil {
t.Fatal(err)
}

if _, err = bReadStream.Read(readBuffer); err != nil {
if _, err = bReadStream.Read(ctx, readBuffer); err != nil {
t.Fatal(err)
}

Expand All @@ -143,6 +152,8 @@ func TestSessionSRTCPOpenReadStream(t *testing.T) {
}

func TestSessionSRTCPReplayProtection(t *testing.T) {
ctx := context.Background()

lim := test.TimeOut(time.Second * 5)
defer lim.Stop()

Expand Down Expand Up @@ -181,7 +192,7 @@ func TestSessionSRTCPReplayProtection(t *testing.T) {
go func() {
defer wg.Done()
for {
if ssrc, perr := getSenderSSRC(t, bReadStream); perr == nil {
if ssrc, perr := getSenderSSRC(ctx, t, bReadStream); perr == nil {
receivedSSRC = append(receivedSSRC, ssrc)
} else if perr == io.EOF {
return
Expand All @@ -191,17 +202,17 @@ func TestSessionSRTCPReplayProtection(t *testing.T) {

// Write with replay attack
for _, p := range packets {
if _, err = aSession.session.nextConn.Write(p); err != nil {
if _, err = aSession.session.nextConn.WriteContext(ctx, p); err != nil {
t.Fatal(err)
}
// Immediately replay
if _, err = aSession.session.nextConn.Write(p); err != nil {
if _, err = aSession.session.nextConn.WriteContext(ctx, p); err != nil {
t.Fatal(err)
}
}
for _, p := range packets {
// Delayed replay
if _, err = aSession.session.nextConn.Write(p); err != nil {
if _, err = aSession.session.nextConn.WriteContext(ctx, p); err != nil {
t.Fatal(err)
}
}
Expand All @@ -224,15 +235,15 @@ func TestSessionSRTCPReplayProtection(t *testing.T) {
}
}

func getSenderSSRC(t *testing.T, stream *ReadStreamSRTCP) (ssrc uint32, err error) {
func getSenderSSRC(ctx context.Context, t *testing.T, stream *ReadStreamSRTCP) (ssrc uint32, err error) {
authTagSize, err := ProtectionProfileAes128CmHmacSha1_80.authTagLen()
if err != nil {
return 0, err
}

const pliPacketSize = 8
readBuffer := make([]byte, pliPacketSize+authTagSize+srtcpIndexSize)
n, _, err := stream.ReadRTCP(readBuffer)
n, _, err := stream.ReadRTCP(ctx, readBuffer)
if err == io.EOF {
return 0, err
}
Expand Down
Loading

0 comments on commit 8cee484

Please sign in to comment.