Skip to content

Commit 1c0fa3e

Browse files
authored
Check if TLS certificate and key file have been modified (#345)
* Check hash of cert and key file Signed-off-by: Levi Harrison <git@leviharrison.dev> Signed-off-by: Simon Pasquier <spasquie@redhat.com>
1 parent 54e041d commit 1c0fa3e

File tree

2 files changed

+184
-22
lines changed

2 files changed

+184
-22
lines changed

config/http_config.go

Lines changed: 69 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -540,7 +540,7 @@ func NewRoundTripperFromConfig(cfg HTTPClientConfig, name string, optFuncs ...HT
540540
return newRT(tlsConfig)
541541
}
542542

543-
return NewTLSRoundTripper(tlsConfig, cfg.TLSConfig.CAFile, newRT)
543+
return NewTLSRoundTripper(tlsConfig, cfg.TLSConfig.CAFile, cfg.TLSConfig.CertFile, cfg.TLSConfig.KeyFile, newRT)
544544
}
545545

546546
type authorizationCredentialsRoundTripper struct {
@@ -709,7 +709,7 @@ func (rt *oauth2RoundTripper) RoundTrip(req *http.Request) (*http.Response, erro
709709
if len(rt.config.TLSConfig.CAFile) == 0 {
710710
t, _ = tlsTransport(tlsConfig)
711711
} else {
712-
t, err = NewTLSRoundTripper(tlsConfig, rt.config.TLSConfig.CAFile, tlsTransport)
712+
t, err = NewTLSRoundTripper(tlsConfig, rt.config.TLSConfig.CAFile, rt.config.TLSConfig.CertFile, rt.config.TLSConfig.KeyFile, tlsTransport)
713713
if err != nil {
714714
return nil, err
715715
}
@@ -838,12 +838,39 @@ func (c *TLSConfig) SetDirectory(dir string) {
838838
c.KeyFile = JoinDir(dir, c.KeyFile)
839839
}
840840

841+
// UnmarshalYAML implements the yaml.Unmarshaler interface.
842+
func (c *TLSConfig) UnmarshalYAML(unmarshal func(interface{}) error) error {
843+
type plain TLSConfig
844+
return unmarshal((*plain)(c))
845+
}
846+
847+
// readCertAndKey reads the cert and key files from the disk.
848+
func readCertAndKey(certFile, keyFile string) ([]byte, []byte, error) {
849+
certData, err := ioutil.ReadFile(certFile)
850+
if err != nil {
851+
return nil, nil, err
852+
}
853+
854+
keyData, err := ioutil.ReadFile(keyFile)
855+
if err != nil {
856+
return nil, nil, err
857+
}
858+
859+
return certData, keyData, nil
860+
}
861+
841862
// getClientCertificate reads the pair of client cert and key from disk and returns a tls.Certificate.
842-
func (c *TLSConfig) getClientCertificate(*tls.CertificateRequestInfo) (*tls.Certificate, error) {
843-
cert, err := tls.LoadX509KeyPair(c.CertFile, c.KeyFile)
863+
func (c *TLSConfig) getClientCertificate(_ *tls.CertificateRequestInfo) (*tls.Certificate, error) {
864+
certData, keyData, err := readCertAndKey(c.CertFile, c.KeyFile)
865+
if err != nil {
866+
return nil, fmt.Errorf("unable to read specified client cert (%s) & key (%s): %s", c.CertFile, c.KeyFile, err)
867+
}
868+
869+
cert, err := tls.X509KeyPair(certData, keyData)
844870
if err != nil {
845871
return nil, fmt.Errorf("unable to use specified client cert (%s) & key (%s): %s", c.CertFile, c.KeyFile, err)
846872
}
873+
847874
return &cert, nil
848875
}
849876

@@ -869,23 +896,30 @@ func updateRootCA(cfg *tls.Config, b []byte) bool {
869896
// tlsRoundTripper is a RoundTripper that updates automatically its TLS
870897
// configuration whenever the content of the CA file changes.
871898
type tlsRoundTripper struct {
872-
caFile string
899+
caFile string
900+
certFile string
901+
keyFile string
902+
873903
// newRT returns a new RoundTripper.
874904
newRT func(*tls.Config) (http.RoundTripper, error)
875905

876-
mtx sync.RWMutex
877-
rt http.RoundTripper
878-
hashCAFile []byte
879-
tlsConfig *tls.Config
906+
mtx sync.RWMutex
907+
rt http.RoundTripper
908+
hashCAFile []byte
909+
hashCertFile []byte
910+
hashKeyFile []byte
911+
tlsConfig *tls.Config
880912
}
881913

882914
func NewTLSRoundTripper(
883915
cfg *tls.Config,
884-
caFile string,
916+
caFile, certFile, keyFile string,
885917
newRT func(*tls.Config) (http.RoundTripper, error),
886918
) (http.RoundTripper, error) {
887919
t := &tlsRoundTripper{
888920
caFile: caFile,
921+
certFile: certFile,
922+
keyFile: keyFile,
889923
newRT: newRT,
890924
tlsConfig: cfg,
891925
}
@@ -895,33 +929,44 @@ func NewTLSRoundTripper(
895929
return nil, err
896930
}
897931
t.rt = rt
898-
_, t.hashCAFile, err = t.getCAWithHash()
932+
_, t.hashCAFile, t.hashCertFile, t.hashKeyFile, err = t.getTLSFilesWithHash()
899933
if err != nil {
900934
return nil, err
901935
}
902936

903937
return t, nil
904938
}
905939

906-
func (t *tlsRoundTripper) getCAWithHash() ([]byte, []byte, error) {
907-
b, err := readCAFile(t.caFile)
940+
func (t *tlsRoundTripper) getTLSFilesWithHash() ([]byte, []byte, []byte, []byte, error) {
941+
b1, err := readCAFile(t.caFile)
908942
if err != nil {
909-
return nil, nil, err
943+
return nil, nil, nil, nil, err
944+
}
945+
h1 := sha256.Sum256(b1)
946+
947+
var h2, h3 [32]byte
948+
if t.certFile != "" {
949+
b2, b3, err := readCertAndKey(t.certFile, t.keyFile)
950+
if err != nil {
951+
return nil, nil, nil, nil, err
952+
}
953+
h2, h3 = sha256.Sum256(b2), sha256.Sum256(b3)
910954
}
911-
h := sha256.Sum256(b)
912-
return b, h[:], nil
913955

956+
return b1, h1[:], h2[:], h3[:], nil
914957
}
915958

916959
// RoundTrip implements the http.RoundTrip interface.
917960
func (t *tlsRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
918-
b, h, err := t.getCAWithHash()
961+
caData, caHash, certHash, keyHash, err := t.getTLSFilesWithHash()
919962
if err != nil {
920963
return nil, err
921964
}
922965

923966
t.mtx.RLock()
924-
equal := bytes.Equal(h[:], t.hashCAFile)
967+
equal := bytes.Equal(caHash[:], t.hashCAFile) &&
968+
bytes.Equal(certHash[:], t.hashCertFile) &&
969+
bytes.Equal(keyHash[:], t.hashKeyFile)
925970
rt := t.rt
926971
t.mtx.RUnlock()
927972
if equal {
@@ -930,8 +975,10 @@ func (t *tlsRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
930975
}
931976

932977
// Create a new RoundTripper.
978+
// The cert and key files are read separately by the client
979+
// using GetClientCertificate.
933980
tlsConfig := t.tlsConfig.Clone()
934-
if !updateRootCA(tlsConfig, b) {
981+
if !updateRootCA(tlsConfig, caData) {
935982
return nil, fmt.Errorf("unable to use specified CA cert %s", t.caFile)
936983
}
937984
rt, err = t.newRT(tlsConfig)
@@ -942,7 +989,9 @@ func (t *tlsRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
942989

943990
t.mtx.Lock()
944991
t.rt = rt
945-
t.hashCAFile = h[:]
992+
t.hashCAFile = caHash[:]
993+
t.hashCertFile = certHash[:]
994+
t.hashKeyFile = keyHash[:]
946995
t.mtx.Unlock()
947996

948997
return rt.RoundTrip(req)

config/http_config_test.go

Lines changed: 115 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -718,15 +718,15 @@ func TestTLSConfigInvalidCA(t *testing.T) {
718718
KeyFile: ClientKeyNoPassPath,
719719
ServerName: "",
720720
InsecureSkipVerify: false},
721-
errorMessage: fmt.Sprintf("unable to use specified client cert (%s) & key (%s):", MissingCert, ClientKeyNoPassPath),
721+
errorMessage: fmt.Sprintf("unable to read specified client cert (%s) & key (%s):", MissingCert, ClientKeyNoPassPath),
722722
}, {
723723
configTLSConfig: TLSConfig{
724724
CAFile: "",
725725
CertFile: ClientCertificatePath,
726726
KeyFile: MissingKey,
727727
ServerName: "",
728728
InsecureSkipVerify: false},
729-
errorMessage: fmt.Sprintf("unable to use specified client cert (%s) & key (%s):", ClientCertificatePath, MissingKey),
729+
errorMessage: fmt.Sprintf("unable to read specified client cert (%s) & key (%s):", ClientCertificatePath, MissingKey),
730730
},
731731
}
732732

@@ -1532,3 +1532,116 @@ func TestOAuth2Proxy(t *testing.T) {
15321532
t.Errorf("Error loading OAuth2 client config: %v", err)
15331533
}
15341534
}
1535+
1536+
func TestModifyTLSCertificates(t *testing.T) {
1537+
bs := getCertificateBlobs(t)
1538+
1539+
tmpDir, err := ioutil.TempDir("", "modifytlscertificates")
1540+
if err != nil {
1541+
t.Fatal("Failed to create tmp dir", err)
1542+
}
1543+
defer os.RemoveAll(tmpDir)
1544+
ca, cert, key := filepath.Join(tmpDir, "ca"), filepath.Join(tmpDir, "cert"), filepath.Join(tmpDir, "key")
1545+
1546+
handler := func(w http.ResponseWriter, r *http.Request) {
1547+
fmt.Fprint(w, ExpectedMessage)
1548+
}
1549+
testServer, err := newTestServer(handler)
1550+
if err != nil {
1551+
t.Fatal(err.Error())
1552+
}
1553+
defer testServer.Close()
1554+
1555+
tests := []struct {
1556+
ca string
1557+
cert string
1558+
key string
1559+
1560+
errMsg string
1561+
1562+
modification func()
1563+
}{
1564+
{
1565+
ca: ClientCertificatePath,
1566+
cert: ClientCertificatePath,
1567+
key: ClientKeyNoPassPath,
1568+
1569+
errMsg: "certificate signed by unknown authority",
1570+
1571+
modification: func() { writeCertificate(bs, TLSCAChainPath, ca) },
1572+
},
1573+
{
1574+
ca: TLSCAChainPath,
1575+
cert: WrongClientCertPath,
1576+
key: ClientKeyNoPassPath,
1577+
1578+
errMsg: "private key does not match public key",
1579+
1580+
modification: func() { writeCertificate(bs, ClientCertificatePath, cert) },
1581+
},
1582+
{
1583+
ca: TLSCAChainPath,
1584+
cert: ClientCertificatePath,
1585+
key: WrongClientCertPath,
1586+
1587+
errMsg: "found a certificate rather than a key in the PEM for the private key",
1588+
1589+
modification: func() { writeCertificate(bs, ClientKeyNoPassPath, key) },
1590+
},
1591+
}
1592+
1593+
cfg := HTTPClientConfig{
1594+
TLSConfig: TLSConfig{
1595+
CAFile: ca,
1596+
CertFile: cert,
1597+
KeyFile: key,
1598+
InsecureSkipVerify: false},
1599+
}
1600+
1601+
var c *http.Client
1602+
for i, tc := range tests {
1603+
t.Run(strconv.Itoa(i), func(t *testing.T) {
1604+
writeCertificate(bs, tc.ca, ca)
1605+
writeCertificate(bs, tc.cert, cert)
1606+
writeCertificate(bs, tc.key, key)
1607+
if c == nil {
1608+
c, err = NewClientFromConfig(cfg, "test")
1609+
if err != nil {
1610+
t.Fatalf("Error creating HTTP Client: %v", err)
1611+
}
1612+
}
1613+
1614+
req, err := http.NewRequest(http.MethodGet, testServer.URL, nil)
1615+
if err != nil {
1616+
t.Fatalf("Error creating HTTP request: %v", err)
1617+
}
1618+
1619+
r, err := c.Do(req)
1620+
if err == nil {
1621+
r.Body.Close()
1622+
t.Fatalf("Could connect to the test server.")
1623+
}
1624+
if !strings.Contains(err.Error(), tc.errMsg) {
1625+
t.Fatalf("Expected error message to contain %q, got %q", tc.errMsg, err)
1626+
}
1627+
1628+
tc.modification()
1629+
1630+
r, err = c.Do(req)
1631+
if err != nil {
1632+
t.Fatalf("Expected no error, got %q", err)
1633+
}
1634+
1635+
b, err := ioutil.ReadAll(r.Body)
1636+
r.Body.Close()
1637+
if err != nil {
1638+
t.Errorf("Can't read the server response body")
1639+
}
1640+
1641+
got := strings.TrimSpace(string(b))
1642+
if ExpectedMessage != got {
1643+
t.Errorf("The expected message %q differs from the obtained message %q", ExpectedMessage, got)
1644+
}
1645+
})
1646+
}
1647+
}

0 commit comments

Comments
 (0)