@@ -4,8 +4,10 @@ import (
44 "crypto/tls"
55 "crypto/x509"
66 "encoding/pem"
7+ "fmt"
78 "io/ioutil"
89 "net"
10+ "sort"
911 "strings"
1012 "time"
1113
@@ -14,6 +16,22 @@ import (
1416 "github.com/pkg/errors"
1517)
1618
19+ type clientCertSubjectField string
20+
21+ const (
22+ clientCertSubjectCommonName = "CN"
23+ clientCertSubjectCountry = "C"
24+ clientCertSubjectProvince = "S"
25+ clientCertSubjectLocality = "L"
26+ clientCertSubjectOrganization = "O"
27+ clientCertSubjectOrganizationalUnit = "OU"
28+ )
29+
30+ type clientCertExpectedData struct {
31+ fields map [clientCertSubjectField ]string
32+ parts []string
33+ }
34+
1735var (
1836 defaultCurvePreferences = []tls.CurveID {
1937 tls .CurveP256 ,
@@ -120,9 +138,142 @@ func newTLSListenerConfig(conf *config.Config) (*tls.Config, error) {
120138 cfg .ClientCAs = clientCAs
121139 cfg .ClientAuth = tls .RequireAndVerifyClientCert
122140 }
141+
142+ cfg .VerifyPeerCertificate = tlsClientCertVerificationFunc (conf )
143+
123144 return cfg , nil
124145}
125146
147+ func tlsClientCertVerificationFunc (conf * config.Config ) func ([][]byte , [][]* x509.Certificate ) error {
148+ expectedData := getClientCertExpectedData (conf )
149+ return func (rawCerts [][]byte , verifiedChains [][]* x509.Certificate ) error {
150+ if conf .Proxy .TLS .ClientCert .ValidateSubject {
151+
152+ if len (expectedData .fields ) == 0 {
153+ return nil // nothing to validate
154+ }
155+
156+ for _ , chain := range verifiedChains {
157+ for _ , cert := range chain {
158+
159+ certificateAcceptable := true
160+
161+ for k , v := range expectedData .fields {
162+ switch k {
163+ case clientCertSubjectCommonName :
164+ if v != cert .Subject .CommonName {
165+ certificateAcceptable = false
166+ break
167+ }
168+ case clientCertSubjectCountry :
169+ currentValues := cert .Subject .Country
170+ sort .Strings (currentValues )
171+ if fmt .Sprintf ("%v" , currentValues ) != v {
172+ certificateAcceptable = false
173+ break
174+ }
175+ case clientCertSubjectProvince :
176+ currentValues := cert .Subject .Province
177+ sort .Strings (currentValues )
178+ if fmt .Sprintf ("%v" , currentValues ) != v {
179+ certificateAcceptable = false
180+ break
181+ }
182+ case clientCertSubjectLocality :
183+ currentValues := cert .Subject .Locality
184+ sort .Strings (currentValues )
185+ if fmt .Sprintf ("%v" , currentValues ) != v {
186+ certificateAcceptable = false
187+ break
188+ }
189+ case clientCertSubjectOrganization :
190+ currentValues := cert .Subject .Organization
191+ sort .Strings (currentValues )
192+ if fmt .Sprintf ("%v" , currentValues ) != v {
193+ certificateAcceptable = false
194+ break
195+ }
196+ case clientCertSubjectOrganizationalUnit :
197+ currentValues := cert .Subject .OrganizationalUnit
198+ sort .Strings (currentValues )
199+ if fmt .Sprintf ("%v" , currentValues ) != v {
200+ certificateAcceptable = false
201+ break
202+ }
203+ }
204+ }
205+
206+ if certificateAcceptable {
207+ return nil
208+ }
209+
210+ }
211+ }
212+
213+ return fmt .Errorf ("tls: no client certificate presented required subject '%s'" , strings .Join (expectedData .parts , "/" ))
214+
215+ }
216+ return nil
217+ }
218+ }
219+
220+ func getClientCertExpectedData (conf * config.Config ) * clientCertExpectedData {
221+
222+ expectedFields := map [clientCertSubjectField ]string {}
223+ expectedParts := []string {"s:" } // these are calculated here because the order is relevant to us
224+ values := []string {}
225+
226+ if conf .Proxy .TLS .ClientCert .Subject .CommonName != "" {
227+ expectedFields [clientCertSubjectCommonName ] = conf .Proxy .TLS .ClientCert .Subject .CommonName
228+ expectedParts = append (expectedParts , fmt .Sprintf ("%s=%s" , clientCertSubjectCommonName , expectedFields [clientCertSubjectCommonName ]))
229+ }
230+ values = removeEmptyStrings (conf .Proxy .TLS .ClientCert .Subject .Country )
231+ if len (values ) > 0 {
232+ sort .Strings (values )
233+ expectedFields [clientCertSubjectCountry ] = fmt .Sprintf ("%v" , values )
234+ expectedParts = append (expectedParts , fmt .Sprintf ("%s=%s" , clientCertSubjectCountry , expectedFields [clientCertSubjectCountry ]))
235+ }
236+ values = removeEmptyStrings (conf .Proxy .TLS .ClientCert .Subject .Province )
237+ if len (values ) > 0 {
238+ sort .Strings (values )
239+ expectedFields [clientCertSubjectProvince ] = fmt .Sprintf ("%v" , values )
240+ expectedParts = append (expectedParts , fmt .Sprintf ("%s=%s" , clientCertSubjectProvince , expectedFields [clientCertSubjectProvince ]))
241+ }
242+ values = removeEmptyStrings (conf .Proxy .TLS .ClientCert .Subject .Locality )
243+ if len (values ) > 0 {
244+ sort .Strings (values )
245+ expectedFields [clientCertSubjectLocality ] = fmt .Sprintf ("%v" , values )
246+ expectedParts = append (expectedParts , fmt .Sprintf ("%s=%s" , clientCertSubjectLocality , expectedFields [clientCertSubjectLocality ]))
247+ }
248+ values = removeEmptyStrings (conf .Proxy .TLS .ClientCert .Subject .Organization )
249+ if len (values ) > 0 {
250+ sort .Strings (values )
251+ expectedFields [clientCertSubjectOrganization ] = fmt .Sprintf ("%v" , values )
252+ expectedParts = append (expectedParts , fmt .Sprintf ("%s=%s" , clientCertSubjectOrganization , expectedFields [clientCertSubjectOrganization ]))
253+ }
254+ values = removeEmptyStrings (conf .Proxy .TLS .ClientCert .Subject .OrganizationalUnit )
255+ if len (values ) > 0 {
256+ sort .Strings (values )
257+ expectedFields [clientCertSubjectOrganizationalUnit ] = fmt .Sprintf ("%v" , values )
258+ expectedParts = append (expectedParts , fmt .Sprintf ("%s=%s" , clientCertSubjectOrganizationalUnit , expectedFields [clientCertSubjectOrganizationalUnit ]))
259+ }
260+ return & clientCertExpectedData {
261+ parts : expectedParts ,
262+ fields : expectedFields ,
263+ }
264+ }
265+
266+ func removeEmptyStrings (input []string ) []string {
267+ output := []string {}
268+ for _ , value := range input {
269+ if value == "" {
270+ continue
271+ }
272+ output = append (output , value )
273+ }
274+ return output
275+ }
276+
126277func getCipherSuites (enabledCipherSuites []string ) ([]uint16 , error ) {
127278 suites := make ([]uint16 , 0 )
128279 for _ , suite := range enabledCipherSuites {
0 commit comments