Skip to content

Commit

Permalink
Move SRTP to streams API
Browse files Browse the repository at this point in the history
Sending has been implemented, receive is still in progress.

Relates to #272
  • Loading branch information
Sean-Der committed Jan 3, 2019
1 parent 3e73209 commit da908cc
Show file tree
Hide file tree
Showing 2 changed files with 217 additions and 61 deletions.
232 changes: 171 additions & 61 deletions session.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@ import (
"fmt"
"net"
"sync"

"github.com/pions/webrtc/pkg/rtp"
)

// SessionSRTP implements io.ReadWriteCloser and provides a bi-directional SRTP session
Expand All @@ -12,36 +14,51 @@ import (
// instead of making everyone re-implement
type SessionSRTP struct {
session
writeStream *WriteStream
}

// CreateSessionSRTP creates a new SessionSRTP
func CreateSessionSRTP(localMasterKey, localMasterSalt, remoteMasterKey, remoteMasterSalt []byte, profile ProtectionProfile, nextConn net.Conn) (*SessionSRTP, error) {
s := &SessionSRTP{
session{nextConn: nextConn, toRead: make(chan []byte)},
session: session{nextConn: nextConn},
}
s.writeStream = &WriteStream{s}

if err := s.session.initalize(localMasterKey, localMasterSalt, remoteMasterKey, remoteMasterSalt, profile /* isRTP */, true); err != nil {
if err := s.session.initalize(localMasterKey, localMasterSalt, remoteMasterKey, remoteMasterSalt, profile, s); err != nil {
return nil, err
}

return s, nil
}

// Read reads from the session and decrypts to RTP
func (s *SessionSRTP) Read(buf []byte) (int, error) {
decrypted, ok := <-s.toRead
// OpenWriteStream returns the global write stream for the Session
func (s *SessionSRTP) OpenWriteStream() (*WriteStream, error) {
return s.writeStream, nil
}

// OpenReadStream opens a read stream for the given SSRC, it can be used
// if you want a certain SSRC, but don't want to wait for AcceptStream
func (s *SessionSRTP) OpenReadStream(SSRC uint32) (*ReadStream, error) {
r, _ := s.session.getOrCreateReadStream(SSRC, s)
return r, nil
}

// AcceptStream returns a stream to handle RTCP for a single SSRC
func (s *SessionSRTP) AcceptStream() (*ReadStream, uint32, error) {
stream, ok := <-s.newStream
if !ok {
return 0, fmt.Errorf("SessionSRTP has been closed")
} else if len(decrypted) > len(buf) {
return 0, fmt.Errorf("Buffer is to small to return RTP")
return nil, 0, fmt.Errorf("SessionSRTP has been closed")
}

copy(buf, decrypted)
return len(decrypted), nil
return stream, stream.GetSSRC(), nil
}

// Close ends the session
func (s *SessionSRTP) Close() error {
return nil
}

// Write encrypts the passed RTP buffer and writes to the session
func (s *SessionSRTP) Write(buf []byte) (int, error) {
func (s *SessionSRTP) write(buf []byte) (int, error) {
s.session.localContextMutex.Lock()
defer s.session.localContextMutex.Unlock()

Expand All @@ -52,8 +69,25 @@ func (s *SessionSRTP) Write(buf []byte) (int, error) {
return s.session.nextConn.Write(encrypted)
}

// Close ends the session
func (s *SessionSRTP) Close() error {
func (s *SessionSRTP) decrypt(buf []byte) error {
decrypted, err := s.remoteContext.DecryptRTP(buf)
if err != nil {
return err
}

p := &rtp.Packet{}
if err := p.Unmarshal(decrypted); err != nil {
return err
}

r, isNew := s.session.getOrCreateReadStream(p.SSRC, s)
if r == nil {
return nil // Session has been closed
} else if isNew {
s.session.newStream <- r // Notify AcceptStream
}

r.decrypted <- decrypted
return nil
}

Expand All @@ -63,36 +97,51 @@ func (s *SessionSRTP) Close() error {
// instead of making everyone re-implement
type SessionSRTCP struct {
session
writeStream *WriteStream
}

// CreateSessionSRTCP creates a new SessionSRTCP
func CreateSessionSRTCP(localMasterKey, localMasterSalt, remoteMasterKey, remoteMasterSalt []byte, profile ProtectionProfile, nextConn net.Conn) (*SessionSRTCP, error) {
s := &SessionSRTCP{
session{nextConn: nextConn, toRead: make(chan []byte)},
session: session{nextConn: nextConn},
}
s.writeStream = &WriteStream{s}

if err := s.session.initalize(localMasterKey, localMasterSalt, remoteMasterKey, remoteMasterSalt, profile /* isRTP */, false); err != nil {
if err := s.session.initalize(localMasterKey, localMasterSalt, remoteMasterKey, remoteMasterSalt, profile, s); err != nil {
return nil, err
}

return s, nil
}

// Read reads from the session and decrypts to RTCP
func (s *SessionSRTCP) Read(buf []byte) (int, error) {
decrypted, ok := <-s.toRead
// OpenWriteStream returns the global write stream for the Session
func (s *SessionSRTCP) OpenWriteStream() (*WriteStream, error) {
return s.writeStream, nil
}

// OpenReadStream opens a read stream for the given SSRC, it can be used
// if you want a certain SSRC, but don't want to wait for AcceptStream
func (s *SessionSRTCP) OpenReadStream(SSRC uint32) (*ReadStream, error) {
r, _ := s.session.getOrCreateReadStream(SSRC, s)
return r, nil
}

// AcceptStream returns a stream to handle RTCP for a single SSRC
func (s *SessionSRTCP) AcceptStream() (*ReadStream, uint32, error) {
stream, ok := <-s.newStream
if !ok {
return 0, fmt.Errorf("SessionSRTCP has been closed")
} else if len(decrypted) > len(buf) {
return 0, fmt.Errorf("Buffer is to small to return RTCP")
return nil, 0, fmt.Errorf("SessionSRTP has been closed")
}

copy(buf, decrypted)
return len(decrypted), nil
return stream, stream.GetSSRC(), nil
}

// Close ends the session
func (s *SessionSRTCP) Close() error {
return nil
}

// Write encrypts the passed RTCP buffer and writes to the session
func (s *SessionSRTCP) Write(buf []byte) (int, error) {
func (s *SessionSRTCP) write(buf []byte) (int, error) {
s.session.localContextMutex.Lock()
defer s.session.localContextMutex.Unlock()

Expand All @@ -103,8 +152,44 @@ func (s *SessionSRTCP) Write(buf []byte) (int, error) {
return s.session.nextConn.Write(encrypted)
}

// Close ends the session
func (s *SessionSRTCP) Close() error {
func (s *SessionSRTCP) decrypt(buf []byte) error {
fmt.Println("TODO SessionSRTCP.decrypt")
// func handleRTCP(getBufferTransports func(uint32) *TransportPair, buffer []byte) {
// //decrypted packets can also be compound packets, so we have to nest our reader loop here.
// compoundPacket := rtcp.NewReader(bytes.NewReader(buffer))
// for {
// _, rawrtcp, err := compoundPacket.ReadPacket()
//
// if err != nil {
// if err == io.EOF {
// break
// }
// fmt.Println(err)
// return
// }
//
// var report rtcp.Packet
// report, _, err = rtcp.Unmarshal(rawrtcp)
// if err != nil {
// fmt.Println(err)
// return
// }
//
// f := func(ssrc uint32) {
// bufferTransport := getBufferTransports(ssrc)
// if bufferTransport != nil && bufferTransport.RTCP != nil {
// select {
// case bufferTransport.RTCP <- report:
// default:
// }
// }
// }
//
// for _, ssrc := range report.DestinationSSRC() {
// f(ssrc)
// }
// }
// }
return nil
}

Expand All @@ -115,51 +200,76 @@ type session struct {
localContextMutex sync.Mutex
localContext, remoteContext *Context

toRead chan []byte
newStream chan *ReadStream

readStreamsClosed bool
readStreams map[uint32]*ReadStream
readStreamsLock sync.Mutex

nextConn net.Conn
}

func (s *session) initalize(localMasterKey, localMasterSalt, remoteMasterKey, remoteMasterSalt []byte, profile ProtectionProfile, isRTP bool) error {
func (s *session) getOrCreateReadStream(ssrc uint32, child streamSession) (*ReadStream, bool) {
s.readStreamsLock.Lock()
defer s.readStreamsLock.Unlock()

if s.readStreamsClosed {
return nil, false
}

isNew := false
r, ok := s.readStreams[ssrc]
if !ok {
r = &ReadStream{s: child, decrypted: make(chan []byte), ssrc: ssrc}
s.readStreams[ssrc] = r

isNew = true
}

return r, isNew
}

func (s *session) initalize(localMasterKey, localMasterSalt, remoteMasterKey, remoteMasterSalt []byte, profile ProtectionProfile, child streamSession) error {
var err error
s.localContext, err = CreateContext(localMasterKey, localMasterSalt, profile)
if err != nil {
return err
}
s.remoteContext, err = CreateContext(remoteMasterKey, remoteMasterSalt, profile)

var decryptFunc func([]byte) ([]byte, error)
if isRTP {
decryptFunc = s.remoteContext.DecryptRTP
} else {
decryptFunc = s.remoteContext.DecryptRTCP
s.remoteContext, err = CreateContext(remoteMasterKey, remoteMasterSalt, profile)
if err != nil {
return err
}

if err == nil {
go func() {
defer func() {
close(s.toRead)
}()

b := make([]byte, 8192)
for {
var i int
i, err = s.nextConn.Read(b)
if err != nil {
fmt.Println(err)
return
}

var decrypted []byte
decrypted, err = decryptFunc(b[:i])
if err != nil {
fmt.Println(err)
return
}

s.toRead <- decrypted
s.readStreams = map[uint32]*ReadStream{}
s.newStream = make(chan *ReadStream)

go func() {
defer func() {
close(s.newStream)

s.readStreamsLock.Lock()
s.readStreamsClosed = true
for _, r := range s.readStreams {
close(r.decrypted)
}
s.readStreamsLock.Unlock()
}()
}
return err

b := make([]byte, 8192)
for {
var i int
i, err = s.nextConn.Read(b)
if err != nil {
fmt.Println(err)
return
}

if err = child.decrypt(b[:i]); err != nil {
fmt.Println(err)
return
}
}
}()
return nil
}
46 changes: 46 additions & 0 deletions stream.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
package srtp

import "fmt"

type streamSession interface {
Close() error
write([]byte) (int, error)
decrypt([]byte) error
}

// ReadStream handles decryption for a single SSRC
type ReadStream struct {
s streamSession

decrypted chan []byte
ssrc uint32
}

// GetSSRC returns the SSRC this ReadStream gets data for
func (r *ReadStream) GetSSRC() uint32 {
return r.ssrc
}

// Read reads decrypted packets from the stream
func (r *ReadStream) Read(buf []byte) (int, error) {
decrypted, ok := <-r.decrypted
if !ok {
return 0, fmt.Errorf("Stream has been closed")
} else if len(decrypted) > len(buf) {
return 0, fmt.Errorf("Buffer is to small to copy")
}

copy(buf, decrypted)
return len(decrypted), nil
}

// WriteStream is stream for a single Session that is used to encrypt
// RTP or RTCP
type WriteStream struct {
session streamSession
}

// Write encrypts the passed RTP/RTCP buffer and writes to the session
func (w *WriteStream) Write(buf []byte) (int, error) {
return w.session.write(buf)
}

0 comments on commit da908cc

Please sign in to comment.