Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use atomic to avoid stale SRTP protection profile #595

Merged
merged 1 commit into from
Nov 15, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 3 additions & 5 deletions conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -394,14 +394,12 @@ func (c *Conn) ConnectionState() State {

// SelectedSRTPProtectionProfile returns the selected SRTPProtectionProfile
func (c *Conn) SelectedSRTPProtectionProfile() (SRTPProtectionProfile, bool) {
c.lock.RLock()
defer c.lock.RUnlock()

if c.state.srtpProtectionProfile == 0 {
profile := c.state.getSRTPProtectionProfile()
if profile == 0 {
return 0, false
}

return c.state.srtpProtectionProfile, true
return profile, true
}

func (c *Conn) writePackets(ctx context.Context, pkts []*packet) error {
Expand Down
2 changes: 1 addition & 1 deletion flight0handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ func flight0Parse(_ context.Context, _ flightConn, state *State, cache *handshak
if !ok {
return 0, &alert.Alert{Level: alert.Fatal, Description: alert.InsufficientSecurity}, errServerNoMatchingSRTPProfile
}
state.srtpProtectionProfile = profile
state.setSRTPProtectionProfile(profile)
case *extension.UseExtendedMasterSecret:
if cfg.extendedMasterSecret != DisableExtendedMasterSecret {
state.extendedMasterSecret = true
Expand Down
4 changes: 2 additions & 2 deletions flight3handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ func flight3Parse(ctx context.Context, c flightConn, state *State, cache *handsh
if !found {
return 0, &alert.Alert{Level: alert.Fatal, Description: alert.IllegalParameter}, errClientNoMatchingSRTPProfile
}
state.srtpProtectionProfile = profile
state.setSRTPProtectionProfile(profile)
case *extension.UseExtendedMasterSecret:
if cfg.extendedMasterSecret != DisableExtendedMasterSecret {
state.extendedMasterSecret = true
Expand All @@ -83,7 +83,7 @@ func flight3Parse(ctx context.Context, c flightConn, state *State, cache *handsh
if cfg.extendedMasterSecret == RequireExtendedMasterSecret && !state.extendedMasterSecret {
return 0, &alert.Alert{Level: alert.Fatal, Description: alert.InsufficientSecurity}, errClientRequiredButNoServerEMS
}
if len(cfg.localSRTPProtectionProfiles) > 0 && state.srtpProtectionProfile == 0 {
if len(cfg.localSRTPProtectionProfiles) > 0 && state.getSRTPProtectionProfile() == 0 {
return 0, &alert.Alert{Level: alert.Fatal, Description: alert.InsufficientSecurity}, errRequestedButNoSRTPExtension
}

Expand Down
4 changes: 2 additions & 2 deletions flight4bhandler.go
Original file line number Diff line number Diff line change
Expand Up @@ -59,9 +59,9 @@ func flight4bGenerate(_ flightConn, state *State, cache *handshakeCache, cfg *ha
Supported: true,
})
}
if state.srtpProtectionProfile != 0 {
if state.getSRTPProtectionProfile() != 0 {
extensions = append(extensions, &extension.UseSRTP{
ProtectionProfiles: []SRTPProtectionProfile{state.srtpProtectionProfile},
ProtectionProfiles: []SRTPProtectionProfile{state.getSRTPProtectionProfile()},
})
}

Expand Down
4 changes: 2 additions & 2 deletions flight4handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -228,9 +228,9 @@ func flight4Generate(_ flightConn, state *State, _ *handshakeCache, cfg *handsha
Supported: true,
})
}
if state.srtpProtectionProfile != 0 {
if state.getSRTPProtectionProfile() != 0 {
extensions = append(extensions, &extension.UseSRTP{
ProtectionProfiles: []SRTPProtectionProfile{state.srtpProtectionProfile},
ProtectionProfiles: []SRTPProtectionProfile{state.getSRTPProtectionProfile()},
})
}
if state.cipherSuite.AuthenticationType() == CipherSuiteAuthenticationTypeCertificate {
Expand Down
18 changes: 15 additions & 3 deletions state.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ type State struct {
cipherSuite CipherSuite // nil if a cipherSuite hasn't been chosen
CipherSuiteID CipherSuiteID

srtpProtectionProfile SRTPProtectionProfile // Negotiated SRTPProtectionProfile
srtpProtectionProfile atomic.Value // Negotiated SRTPProtectionProfile
PeerCertificates [][]byte
IdentityHint []byte
SessionID []byte
Expand Down Expand Up @@ -106,7 +106,7 @@ func (s *State) serialize() *serializedState {
SequenceNumber: atomic.LoadUint64(&s.localSequenceNumber[epoch]),
LocalRandom: localRnd,
RemoteRandom: remoteRnd,
SRTPProtectionProfile: uint16(s.srtpProtectionProfile),
SRTPProtectionProfile: uint16(s.getSRTPProtectionProfile()),
PeerCertificates: s.PeerCertificates,
IdentityHint: s.IdentityHint,
SessionID: s.SessionID,
Expand Down Expand Up @@ -145,7 +145,7 @@ func (s *State) deserialize(serialized serializedState) {
s.cipherSuite = cipherSuiteForID(s.CipherSuiteID, nil)

atomic.StoreUint64(&s.localSequenceNumber[epoch], serialized.SequenceNumber)
s.srtpProtectionProfile = SRTPProtectionProfile(serialized.SRTPProtectionProfile)
s.setSRTPProtectionProfile(SRTPProtectionProfile(serialized.SRTPProtectionProfile))

// Set remote certificate
s.PeerCertificates = serialized.PeerCertificates
Expand Down Expand Up @@ -242,3 +242,15 @@ func (s *State) getLocalEpoch() uint16 {
}
return 0
}

func (s *State) setSRTPProtectionProfile(profile SRTPProtectionProfile) {
s.srtpProtectionProfile.Store(profile)
}

func (s *State) getSRTPProtectionProfile() SRTPProtectionProfile {
if val, ok := s.srtpProtectionProfile.Load().(SRTPProtectionProfile); ok {
return val
}

return 0
}