Skip to content

Commit

Permalink
Add the ability to close ReadStreams
Browse files Browse the repository at this point in the history
This allows us to close/remove ReadStreams from a session at anytime.

Resolves #3
  • Loading branch information
Sean-Der committed Feb 6, 2019
1 parent 93a8f4f commit a95c261
Show file tree
Hide file tree
Showing 6 changed files with 108 additions and 5 deletions.
1 change: 1 addition & 0 deletions .golangci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ linters:
- dupl
- gocritic
- gochecknoglobals
- maligned

issues:
exclude-use-default: false
11 changes: 11 additions & 0 deletions session.go
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,17 @@ func (s *session) getOrCreateReadStream(ssrc uint32, child streamSession, proto
return r, false
}

func (s *session) removeReadStream(ssrc uint32) {
s.readStreamsLock.Lock()
defer s.readStreamsLock.Unlock()

if s.readStreamsClosed {
return
}

delete(s.readStreams, ssrc)
}

func (s *session) close() error {
if s.nextConn == nil {
return nil
Expand Down
4 changes: 4 additions & 0 deletions session_srtcp.go
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,10 @@ func (s *SessionSRTCP) decrypt(buf []byte) error {
return fmt.Errorf("failed to get/create ReadStreamSRTP")
}

// Ensure that readStream.Close() isn't called while in flight
readStream.mu.Lock()
defer readStream.mu.Unlock()

readBuf := <-readStream.readCh
if len(readBuf) < len(decrypted) {
return fmt.Errorf("input buffer was not long enough to contain decrypted RTCP")
Expand Down
4 changes: 4 additions & 0 deletions session_srtp.go
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,10 @@ func (s *SessionSRTP) decrypt(buf []byte) error {
return fmt.Errorf("failed to get/create ReadStreamSRTP")
}

// Ensure that readStream.Close() isn't called while in flight
readStream.mu.Lock()
defer readStream.mu.Unlock()

readBuf := <-readStream.readCh
decrypted, err := s.remoteContext.decryptRTP(readBuf, buf, h)
if err != nil {
Expand Down
48 changes: 45 additions & 3 deletions stream_srtcp.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package srtp

import (
"fmt"
"sync"

"github.com/pions/rtcp"
)
Expand All @@ -13,6 +14,11 @@ type readResultSRTCP struct {

// ReadStreamSRTCP handles decryption for a single RTCP SSRC
type ReadStreamSRTCP struct {
mu sync.Mutex

isInited bool
isClosed chan bool

session *SessionSRTCP
ssrc uint32
readCh chan []byte
Expand All @@ -25,12 +31,18 @@ func (r *ReadStreamSRTCP) ReadRTCP(payload []byte) (int, *rtcp.Header, error) {
case <-r.session.closed:
return 0, nil, fmt.Errorf("SRTCP session is closed")
case r.readCh <- payload:
case <-r.isClosed:
return 0, nil, fmt.Errorf("SRTCP read stream is closed")
}

select {
case <-r.session.closed:
return 0, nil, fmt.Errorf("SRTCP session is closed")
case res := <-r.readRetCh:
case res, ok := <-r.readRetCh:
if !ok {
return 0, nil, fmt.Errorf("SRTCP read stream is closed")
}

return res.len, res.header, nil
}
}
Expand All @@ -41,26 +53,56 @@ func (r *ReadStreamSRTCP) Read(b []byte) (int, error) {
case <-r.session.closed:
return 0, fmt.Errorf("SRTCP session is closed")
case r.readCh <- b:
case <-r.isClosed:
return 0, fmt.Errorf("SRTCP read stream is closed")
}

select {
case <-r.session.closed:
return 0, fmt.Errorf("SRTCP session is closed")
case res := <-r.readRetCh:
case res, ok := <-r.readRetCh:
if !ok {
return 0, fmt.Errorf("SRTCP read stream is closed")
}
return res.len, nil
}
}

// Close removes the ReadStream from the session and cleans up any associated state
func (r *ReadStreamSRTCP) Close() error {
r.mu.Lock()
defer r.mu.Unlock()
if !r.isInited {
return fmt.Errorf("ReadStreamSRTCP has not been inited")
}

select {
case <-r.isClosed:
return fmt.Errorf("ReadStreamSRTCP is already closed")
default:
close(r.readRetCh)
r.session.removeReadStream(r.ssrc)
return nil
}
}

func (r *ReadStreamSRTCP) init(child streamSession, ssrc uint32) error {
sessionSRTCP, ok := child.(*SessionSRTCP)

r.mu.Lock()
defer r.mu.Unlock()
if !ok {
return fmt.Errorf("ReadStreamSRTCP init failed type assertion")
} else if r.isInited {
return fmt.Errorf("ReadStreamSRTCP has already been inited")
}

r.session = sessionSRTCP
r.ssrc = ssrc
r.readCh = make(chan []byte)
r.readRetCh = make(chan readResultSRTCP)
r.isInited = true
r.isClosed = make(chan bool)
return nil
}

Expand All @@ -84,7 +126,7 @@ func (w *WriteStreamSRTCP) WriteRTCP(header *rtcp.Header, payload []byte) (int,
return w.session.write(append(headerRaw, payload...))
}

// Write encrypts and writes a full RTP packets to the nextConn
// Write encrypts and writes a full RTCP packets to the nextConn
func (w *WriteStreamSRTCP) Write(b []byte) (int, error) {
return w.session.write(b)
}
45 changes: 43 additions & 2 deletions stream_srtp.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package srtp

import (
"fmt"
"sync"

"github.com/pions/rtp"
)
Expand All @@ -13,6 +14,11 @@ type readResultSRTP struct {

// ReadStreamSRTP handles decryption for a single RTP SSRC
type ReadStreamSRTP struct {
mu sync.Mutex

isInited bool
isClosed chan bool

session *SessionSRTP
ssrc uint32
readCh chan []byte
Expand All @@ -25,12 +31,17 @@ func (r *ReadStreamSRTP) ReadRTP(payload []byte) (int, *rtp.Header, error) {
case <-r.session.closed:
return 0, nil, fmt.Errorf("SRTP session is closed")
case r.readCh <- payload:
case <-r.isClosed:
return 0, nil, fmt.Errorf("SRTP read stream is closed")
}

select {
case <-r.session.closed:
return 0, nil, fmt.Errorf("SRTP session is closed")
case res := <-r.readRetCh:
case res, ok := <-r.readRetCh:
if !ok {
return 0, nil, fmt.Errorf("SRTP read stream is closed")
}
return res.len, res.header, nil
}
}
Expand All @@ -41,26 +52,56 @@ func (r *ReadStreamSRTP) Read(b []byte) (int, error) {
case <-r.session.closed:
return 0, fmt.Errorf("SRTP session is closed")
case r.readCh <- b:
case <-r.isClosed:
return 0, fmt.Errorf("SRTP read stream is closed")
}

select {
case <-r.session.closed:
return 0, fmt.Errorf("SRTP session is closed")
case res := <-r.readRetCh:
case res, ok := <-r.readRetCh:
if !ok {
return 0, fmt.Errorf("SRTP read stream is closed")
}
return res.len, nil
}
}

// Close removes the ReadStream from the session and cleans up any associated state
func (r *ReadStreamSRTP) Close() error {
r.mu.Lock()
defer r.mu.Unlock()
if !r.isInited {
return fmt.Errorf("ReadStreamSRTP has not been inited")
}

select {
case <-r.isClosed:
return fmt.Errorf("ReadStreamSRTP is already closed")
default:
close(r.readRetCh)
r.session.removeReadStream(r.ssrc)
return nil
}
}

func (r *ReadStreamSRTP) init(child streamSession, ssrc uint32) error {
sessionSRTP, ok := child.(*SessionSRTP)

r.mu.Lock()
defer r.mu.Unlock()
if !ok {
return fmt.Errorf("ReadStreamSRTP init failed type assertion")
} else if r.isInited {
return fmt.Errorf("ReadStreamSRTP has already been inited")
}

r.session = sessionSRTP
r.ssrc = ssrc
r.readCh = make(chan []byte)
r.readRetCh = make(chan readResultSRTP)
r.isInited = true
r.isClosed = make(chan bool)
return nil
}

Expand Down

0 comments on commit a95c261

Please sign in to comment.