Skip to content

Commit

Permalink
Use atomic to avoid stale SRTP protection profile
Browse files Browse the repository at this point in the history
`state` is acccessed without lock in the FSM.
In some cases, that leads to stale values.
For example, `srtpProtectionProfile` is set in flight
handlers (differnt flight handlers in client and server).
But, when it is accessed via the API
`SelectedSRTPProtectionProfile`,
it gets a stale value as it appears that the two goroutines
are out-of-sync on that piece of shared memory.

This is a larger concern for use of `state`.
Ideally, either
- `state` should have a lock internally and all fields are accessed
  through methods.
- carefully split fields of `state` to ensure process access/sync.

Doing the smaller change here to address one field that has
seen stale value.
  • Loading branch information
boks1971 committed Nov 15, 2023
1 parent 5c0a7c1 commit 3cc07a0
Show file tree
Hide file tree
Showing 6 changed files with 25 additions and 15 deletions.
8 changes: 3 additions & 5 deletions conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -374,14 +374,12 @@ func (c *Conn) ConnectionState() State {

// SelectedSRTPProtectionProfile returns the selected SRTPProtectionProfile
func (c *Conn) SelectedSRTPProtectionProfile() (SRTPProtectionProfile, bool) {
c.lock.RLock()
defer c.lock.RUnlock()

if c.state.srtpProtectionProfile == 0 {
profile := c.state.getSRTPProtectionProfile()
if profile == 0 {
return 0, false
}

return c.state.srtpProtectionProfile, true
return profile, true
}

func (c *Conn) writePackets(ctx context.Context, pkts []*packet) error {
Expand Down
2 changes: 1 addition & 1 deletion flight0handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ func flight0Parse(_ context.Context, _ flightConn, state *State, cache *handshak
if !ok {
return 0, &alert.Alert{Level: alert.Fatal, Description: alert.InsufficientSecurity}, errServerNoMatchingSRTPProfile
}
state.srtpProtectionProfile = profile
state.setSRTPProtectionProfile(profile)
case *extension.UseExtendedMasterSecret:
if cfg.extendedMasterSecret != DisableExtendedMasterSecret {
state.extendedMasterSecret = true
Expand Down
4 changes: 2 additions & 2 deletions flight3handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ func flight3Parse(ctx context.Context, c flightConn, state *State, cache *handsh
if !found {
return 0, &alert.Alert{Level: alert.Fatal, Description: alert.IllegalParameter}, errClientNoMatchingSRTPProfile
}
state.srtpProtectionProfile = profile
state.setSRTPProtectionProfile(profile)
case *extension.UseExtendedMasterSecret:
if cfg.extendedMasterSecret != DisableExtendedMasterSecret {
state.extendedMasterSecret = true
Expand All @@ -71,7 +71,7 @@ func flight3Parse(ctx context.Context, c flightConn, state *State, cache *handsh
if cfg.extendedMasterSecret == RequireExtendedMasterSecret && !state.extendedMasterSecret {
return 0, &alert.Alert{Level: alert.Fatal, Description: alert.InsufficientSecurity}, errClientRequiredButNoServerEMS
}
if len(cfg.localSRTPProtectionProfiles) > 0 && state.srtpProtectionProfile == 0 {
if len(cfg.localSRTPProtectionProfiles) > 0 && state.getSRTPProtectionProfile() == 0 {
return 0, &alert.Alert{Level: alert.Fatal, Description: alert.InsufficientSecurity}, errRequestedButNoSRTPExtension
}

Expand Down
4 changes: 2 additions & 2 deletions flight4bhandler.go
Original file line number Diff line number Diff line change
Expand Up @@ -59,9 +59,9 @@ func flight4bGenerate(_ flightConn, state *State, cache *handshakeCache, cfg *ha
Supported: true,
})
}
if state.srtpProtectionProfile != 0 {
if state.getSRTPProtectionProfile() != 0 {
extensions = append(extensions, &extension.UseSRTP{
ProtectionProfiles: []SRTPProtectionProfile{state.srtpProtectionProfile},
ProtectionProfiles: []SRTPProtectionProfile{state.getSRTPProtectionProfile()},
})
}

Expand Down
4 changes: 2 additions & 2 deletions flight4handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -228,9 +228,9 @@ func flight4Generate(_ flightConn, state *State, _ *handshakeCache, cfg *handsha
Supported: true,
})
}
if state.srtpProtectionProfile != 0 {
if state.getSRTPProtectionProfile() != 0 {
extensions = append(extensions, &extension.UseSRTP{
ProtectionProfiles: []SRTPProtectionProfile{state.srtpProtectionProfile},
ProtectionProfiles: []SRTPProtectionProfile{state.getSRTPProtectionProfile()},
})
}
if state.cipherSuite.AuthenticationType() == CipherSuiteAuthenticationTypeCertificate {
Expand Down
18 changes: 15 additions & 3 deletions state.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ type State struct {
masterSecret []byte
cipherSuite CipherSuite // nil if a cipherSuite hasn't been chosen

srtpProtectionProfile SRTPProtectionProfile // Negotiated SRTPProtectionProfile
srtpProtectionProfile atomic.Value // Negotiated SRTPProtectionProfile
PeerCertificates [][]byte
IdentityHint []byte
SessionID []byte
Expand Down Expand Up @@ -87,7 +87,7 @@ func (s *State) serialize() *serializedState {
SequenceNumber: atomic.LoadUint64(&s.localSequenceNumber[epoch]),
LocalRandom: localRnd,
RemoteRandom: remoteRnd,
SRTPProtectionProfile: uint16(s.srtpProtectionProfile),
SRTPProtectionProfile: uint16(s.getSRTPProtectionProfile()),
PeerCertificates: s.PeerCertificates,
IdentityHint: s.IdentityHint,
SessionID: s.SessionID,
Expand Down Expand Up @@ -123,7 +123,7 @@ func (s *State) deserialize(serialized serializedState) {
s.cipherSuite = cipherSuiteForID(CipherSuiteID(serialized.CipherSuiteID), nil)

atomic.StoreUint64(&s.localSequenceNumber[epoch], serialized.SequenceNumber)
s.srtpProtectionProfile = SRTPProtectionProfile(serialized.SRTPProtectionProfile)
s.setSRTPProtectionProfile(SRTPProtectionProfile(serialized.SRTPProtectionProfile))

// Set remote certificate
s.PeerCertificates = serialized.PeerCertificates
Expand Down Expand Up @@ -214,3 +214,15 @@ func (s *State) getLocalEpoch() uint16 {
}
return 0
}

func (s *State) setSRTPProtectionProfile(profile SRTPProtectionProfile) {
s.srtpProtectionProfile.Store(profile)
}

func (s *State) getSRTPProtectionProfile() SRTPProtectionProfile {
if val, ok := s.srtpProtectionProfile.Load().(SRTPProtectionProfile); ok {
return val
}

return 0
}

0 comments on commit 3cc07a0

Please sign in to comment.