Skip to content

Commit

Permalink
Use new SRTP API, instead of just Read/Write
Browse files Browse the repository at this point in the history
Relates to #272
  • Loading branch information
Sean-Der committed Jan 13, 2019
1 parent 90016b6 commit 1d79249
Show file tree
Hide file tree
Showing 5 changed files with 51 additions and 47 deletions.
9 changes: 3 additions & 6 deletions session.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,19 +36,16 @@ func (s *session) getOrCreateReadStream(ssrc uint32, child streamSession, proto
return nil, false
}

isNew := false
r, ok := s.readStreams[ssrc]
if !ok {
if err := proto.init(child, ssrc); err != nil {
return nil, false
}

r = proto
isNew = true
s.readStreams[ssrc] = r
s.readStreams[ssrc] = proto
return proto, true
}

return r, isNew
return r, false
}

func (s *session) initalize() {
Expand Down
21 changes: 8 additions & 13 deletions session_srtcp.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,28 +9,18 @@ import (
"github.com/pions/webrtc/pkg/rtcp"
)

type readResultSRTCP struct {
len int
header *rtcp.Header
}

// SessionSRTCP implements io.ReadWriteCloser and provides a bi-directional SRTCP session
// SRTCP itself does not have a design like this, but it is common in most applications
// for local/remote to each have their own keying material. This provides those patterns
// instead of making everyone re-implement
type SessionSRTCP struct {
session
writeStream *WriteStreamSRTCP
readCh chan []byte
readRetCh chan readResultSRTCP
}

// CreateSessionSRTCP creates a new SessionSRTCP
func CreateSessionSRTCP() *SessionSRTCP {
s := &SessionSRTCP{
readCh: make(chan []byte),
readRetCh: make(chan readResultSRTCP),
}
s := &SessionSRTCP{}
s.writeStream = &WriteStreamSRTCP{s}
s.session.initalize()
return s
Expand Down Expand Up @@ -125,15 +115,20 @@ func (s *SessionSRTCP) decrypt(buf []byte) error {
s.session.newStream <- r // Notify AcceptStream
}

readBuf := <-s.readCh
readStream, ok := r.(*ReadStreamSRTCP)
if !ok {
return fmt.Errorf("Failed to get/create ReadStreamSRTP")
}

readBuf := <-readStream.readCh
if len(readBuf) < len(decrypted) {
return fmt.Errorf("Input buffer was not long enough to contain decrypted RTCP")
}

copy(readBuf, decrypted)
h := report.Header()

s.readRetCh <- readResultSRTCP{
readStream.readRetCh <- readResultSRTCP{
len: len(decrypted),
header: &h,
}
Expand Down
22 changes: 9 additions & 13 deletions session_srtp.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,28 +7,18 @@ import (
"github.com/pions/webrtc/pkg/rtp"
)

type readResultSRTP struct {
len int
header *rtp.Header
}

// SessionSRTP implements io.ReadWriteCloser and provides a bi-directional SRTP session
// SRTP itself does not have a design like this, but it is common in most applications
// for local/remote to each have their own keying material. This provides those patterns
// instead of making everyone re-implement
type SessionSRTP struct {
session
writeStream *WriteStreamSRTP
readCh chan []byte
readRetCh chan readResultSRTP
}

// CreateSessionSRTP creates a new SessionSRTP
func CreateSessionSRTP() *SessionSRTP {
s := &SessionSRTP{
readCh: make(chan []byte),
readRetCh: make(chan readResultSRTP),
}
s := &SessionSRTP{}
s.writeStream = &WriteStreamSRTP{s}
s.session.initalize()
return s
Expand Down Expand Up @@ -104,17 +94,23 @@ func (s *SessionSRTP) decrypt(buf []byte) error {
s.session.newStream <- r // Notify AcceptStream
}

readBuf := <-s.readCh
readStream, ok := r.(*ReadStreamSRTP)
if !ok {
return fmt.Errorf("Failed to get/create ReadStreamSRTP")
}

readBuf := <-readStream.readCh
decrypted, err := s.remoteContext.decryptRTP(readBuf, buf, h)
if err != nil {
return err
} else if len(decrypted) > len(readBuf) {
return fmt.Errorf("Input buffer was not long enough to contain decrypted RTP")
}

s.readRetCh <- readResultSRTP{
readStream.readRetCh <- readResultSRTP{
len: len(decrypted),
header: h,
}

return nil
}
22 changes: 15 additions & 7 deletions stream_srtcp.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,24 +6,31 @@ import (
"github.com/pions/webrtc/pkg/rtcp"
)

type readResultSRTCP struct {
len int
header *rtcp.Header
}

// ReadStreamSRTCP handles decryption for a single RTCP SSRC
type ReadStreamSRTCP struct {
session *SessionSRTCP
ssrc uint32
session *SessionSRTCP
ssrc uint32
readCh chan []byte
readRetCh chan readResultSRTCP
}

// ReadRTCP reads and decrypts full RTCP packet and its header from the nextConn
func (r *ReadStreamSRTCP) ReadRTCP(payload []byte) (int, *rtcp.Header, error) {
select {
case <-r.session.closed:
return 0, nil, fmt.Errorf("SRTCP session is closed")
case r.session.readCh <- payload:
case r.readCh <- payload:
}

select {
case <-r.session.closed:
return 0, nil, fmt.Errorf("SRTCP session is closed")
case res := <-r.session.readRetCh:
case res := <-r.readRetCh:
return res.len, res.header, nil
}
}
Expand All @@ -33,13 +40,13 @@ func (r *ReadStreamSRTCP) Read(b []byte) (int, error) {
select {
case <-r.session.closed:
return 0, fmt.Errorf("SRTCP session is closed")
case r.session.readCh <- b:
case r.readCh <- b:
}

select {
case <-r.session.closed:
return 0, fmt.Errorf("SRTCP session is closed")
case res := <-r.session.readRetCh:
case res := <-r.readRetCh:
return res.len, nil
}
}
Expand All @@ -52,8 +59,9 @@ func (r *ReadStreamSRTCP) init(child streamSession, ssrc uint32) error {

r.session = sessionSRTCP
r.ssrc = ssrc
r.readCh = make(chan []byte)
r.readRetCh = make(chan readResultSRTCP)
return nil

}

// GetSSRC returns the SSRC we are demuxing for
Expand Down
24 changes: 16 additions & 8 deletions stream_srtp.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,28 +3,34 @@ package srtp
import (
"fmt"

"github.com/pions/webrtc/pkg/rtcp"
"github.com/pions/webrtc/pkg/rtp"
)

type readResultSRTP struct {
len int
header *rtp.Header
}

// ReadStreamSRTP handles decryption for a single RTP SSRC
type ReadStreamSRTP struct {
session *SessionSRTP
ssrc uint32
session *SessionSRTP
ssrc uint32
readCh chan []byte
readRetCh chan readResultSRTP
}

// ReadRTP reads and decrypts full RTP packet and its header from the nextConn
func (r *ReadStreamSRTP) ReadRTP(payload []byte) (int, *rtp.Header, error) {
select {
case <-r.session.closed:
return 0, nil, fmt.Errorf("SRTP session is closed")
case r.session.readCh <- payload:
case r.readCh <- payload:
}

select {
case <-r.session.closed:
return 0, nil, fmt.Errorf("SRTP session is closed")
case res := <-r.session.readRetCh:
case res := <-r.readRetCh:
return res.len, res.header, nil
}
}
Expand All @@ -34,13 +40,13 @@ func (r *ReadStreamSRTP) Read(b []byte) (int, error) {
select {
case <-r.session.closed:
return 0, fmt.Errorf("SRTP session is closed")
case r.session.readCh <- b:
case r.readCh <- b:
}

select {
case <-r.session.closed:
return 0, fmt.Errorf("SRTP session is closed")
case res := <-r.session.readRetCh:
case res := <-r.readRetCh:
return res.len, nil
}
}
Expand All @@ -53,6 +59,8 @@ func (r *ReadStreamSRTP) init(child streamSession, ssrc uint32) error {

r.session = sessionSRTP
r.ssrc = ssrc
r.readCh = make(chan []byte)
r.readRetCh = make(chan readResultSRTP)
return nil
}

Expand All @@ -67,7 +75,7 @@ type WriteStreamSRTP struct {
}

// WriteRTP encrypts a RTP header and its payload to the nextConn
func (w *WriteStreamSRTP) WriteRTP(header *rtcp.Header, payload []byte) (int, error) {
func (w *WriteStreamSRTP) WriteRTP(header *rtp.Header, payload []byte) (int, error) {
headerRaw, err := header.Marshal()
if err != nil {
return 0, err
Expand Down

0 comments on commit 1d79249

Please sign in to comment.