forked from pion/dtls
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathflight0handler.go
127 lines (105 loc) · 4.26 KB
/
flight0handler.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
package dtls
import (
"context"
"crypto/rand"
"github.com/pion/dtls/v2/pkg/crypto/elliptic"
"github.com/pion/dtls/v2/pkg/protocol"
"github.com/pion/dtls/v2/pkg/protocol/alert"
"github.com/pion/dtls/v2/pkg/protocol/extension"
"github.com/pion/dtls/v2/pkg/protocol/handshake"
)
// flight0Parse processes a cached ClientHello and records the negotiated
// parameters on state.
//
// It pulls the ClientHello for cfg.initialEpoch from cache. When none is
// available it returns (0, nil, nil) so the caller keeps reading. Otherwise it
// checks the protocol version, stores the peer random, selects a cipher suite
// shared with cfg.localCipherSuites, applies the recognised extensions,
// enforces the extended-master-secret policy, makes sure a local keypair
// exists, and finally hands the ClientHello session ID to handleHelloResume
// with flight2 as the non-resumed next flight.
//
// On failure a fatal alert is returned; every failure except an unexpected
// message type also carries a non-nil error.
func flight0Parse(ctx context.Context, c flightConn, state *State, cache *handshakeCache, cfg *handshakeConfig) (flightVal, *alert.Alert, error) {
	helloRule := handshakeCachePullRule{handshake.TypeClientHello, cfg.initialEpoch, true, false}
	recvSeq, pulled, found := cache.fullPullMap(0, state.cipherSuite, helloRule)
	if !found {
		// Nothing usable in the cache yet; stay in this flight.
		return 0, nil, nil
	}
	state.handshakeRecvSequence = recvSeq

	// The cached entry must really be a ClientHello.
	hello, isHello := pulled[handshake.TypeClientHello].(*handshake.MessageClientHello)
	if !isHello {
		return 0, &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, nil
	}

	if !hello.Version.Equal(protocol.Version1_2) {
		return 0, &alert.Alert{Level: alert.Fatal, Description: alert.ProtocolVersion}, errUnsupportedProtocolVersion
	}

	state.remoteRandom = hello.Random

	// Keep only the offered suite IDs that resolve to a known suite.
	offered := make([]CipherSuite, 0, len(hello.CipherSuiteIDs))
	for _, rawID := range hello.CipherSuiteIDs {
		suite := cipherSuiteForID(CipherSuiteID(rawID), cfg.customCipherSuites)
		if suite == nil {
			continue
		}
		offered = append(offered, suite)
	}

	// The result is stored on state before the match flag is inspected.
	matched, haveMatch := findMatchingCipherSuite(offered, cfg.localCipherSuites)
	state.cipherSuite = matched
	if !haveMatch {
		return 0, &alert.Alert{Level: alert.Fatal, Description: alert.InsufficientSecurity}, errCipherSuiteNoIntersection
	}

	for _, rawExt := range hello.Extensions {
		switch ext := rawExt.(type) {
		case *extension.SupportedEllipticCurves:
			if len(ext.EllipticCurves) == 0 {
				return 0, &alert.Alert{Level: alert.Fatal, Description: alert.InsufficientSecurity}, errNoSupportedEllipticCurves
			}
			// The first curve listed by the peer is the one used.
			state.namedCurve = ext.EllipticCurves[0]

		case *extension.UseSRTP:
			chosen, matchedProfile := findMatchingSRTPProfile(ext.ProtectionProfiles, cfg.localSRTPProtectionProfiles)
			if !matchedProfile {
				return 0, &alert.Alert{Level: alert.Fatal, Description: alert.InsufficientSecurity}, errServerNoMatchingSRTPProfile
			}
			state.srtpProtectionProfile = chosen

		case *extension.UseExtendedMasterSecret:
			// Only turned on here; never cleared by this extension.
			if cfg.extendedMasterSecret != DisableExtendedMasterSecret {
				state.extendedMasterSecret = true
			}

		case *extension.ServerName:
			state.serverName = ext.ServerName // server name sent by the peer

		case *extension.ALPN:
			state.peerSupportedProtocols = ext.ProtocolNameList
		}
	}

	emsRequired := cfg.extendedMasterSecret == RequireExtendedMasterSecret
	if emsRequired && !state.extendedMasterSecret {
		return 0, &alert.Alert{Level: alert.Fatal, Description: alert.InsufficientSecurity}, errServerRequiredButNoClientEMS
	}

	if state.localKeypair == nil {
		// The generated value is stored before the error is checked.
		keypair, genErr := elliptic.GenerateKeypair(state.namedCurve)
		state.localKeypair = keypair
		if genErr != nil {
			return 0, &alert.Alert{Level: alert.Fatal, Description: alert.IllegalParameter}, genErr
		}
	}

	return handleHelloResume(hello.SessionID, state, cfg, flight2)
}
// handleHelloResume decides whether the handshake resumes a stored session.
//
// When sessionID is non-empty and cfg.sessionStore is configured, the session
// is looked up in the store. If the returned session has a non-nil ID, its
// secret becomes state.masterSecret, sessionID is recorded on state, the
// cipher suite is initialised, a key-log line is written, and flight4b is
// returned. In every other non-error case next is returned unchanged.
//
// A store lookup error or a cipher suite initialisation error yields a fatal
// InternalError alert together with the underlying error.
func handleHelloResume(sessionID []byte, state *State, cfg *handshakeConfig, next flightVal) (flightVal, *alert.Alert, error) {
	if len(sessionID) == 0 || cfg.sessionStore == nil {
		return next, nil, nil
	}

	s, err := cfg.sessionStore.Get(sessionID)
	if err != nil {
		return 0, &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, err
	}
	if s.ID == nil {
		// No stored session for this ID: continue with the regular handshake.
		return next, nil, nil
	}

	cfg.log.Tracef("[handshake] resume session: %x", sessionID)

	state.SessionID = sessionID
	state.masterSecret = s.Secret

	if err := state.initCipherSuite(); err != nil {
		return 0, &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, err
	}

	// The key-log line is keyed by the client's random. This function runs
	// after a ClientHello has been parsed, where the client's random is stored
	// in state.remoteRandom; state.localRandom holds this endpoint's own random
	// and would produce a line that cannot be matched to the session.
	clientRandom := state.remoteRandom.MarshalFixed()
	cfg.writeKeyLog(keyLogLabelTLS12, clientRandom[:], state.masterSecret)

	return flight4b, nil, nil
}
// flight0Generate resets the per-handshake state for flight 0.
//
// It installs a fresh random cookie of cookieLength bytes, sets the local and
// remote epochs to zero, selects defaultNamedCurve, and populates the local
// random. It never produces packets or an alert; the only non-nil result is
// an error from the random source or from populating the local random.
func flight0Generate(c flightConn, state *State, cache *handshakeCache, cfg *handshakeConfig) ([]*packet, *alert.Alert, error) {
	// The cookie buffer is attached to state before it is filled.
	state.cookie = make([]byte, cookieLength)
	_, cookieErr := rand.Read(state.cookie)
	if cookieErr != nil {
		return nil, nil, cookieErr
	}

	// Both directions start at epoch zero.
	state.localEpoch.Store(uint16(0))
	state.remoteEpoch.Store(uint16(0))

	state.namedCurve = defaultNamedCurve

	return nil, nil, state.localRandom.Populate()
}