Skip to content

Commit ce997a6

Browse files
committed
add unit testing and documentation
1 parent 6653a37 commit ce997a6

File tree

2 files changed

+79
-1
lines changed

2 files changed

+79
-1
lines changed

security/advancedtls/advancedtls.go

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -556,6 +556,11 @@ func (c *advancedTLSCreds) OverrideServerName(serverNameOverride string) error {
556556
// 1. does not have a good support on root cert reloading.
557557
// 2. will ignore basic certificate check when setting InsecureSkipVerify
558558
// to true.
559+
//
560+
// peerVerifiedChains(output param): verified chain of certs from leaf to the
561+
// trust cert that the peer trusts.
562+
// 1. For server it is, client certs + Root ca that the server trusts
563+
// 2. For client it is, server certs + Root ca that the client trusts
559564
func buildVerifyFunc(c *advancedTLSCreds,
560565
serverName string,
561566
rawConn net.Conn,
@@ -637,7 +642,9 @@ func buildVerifyFunc(c *advancedTLSCreds,
637642
VerifiedChains: chains,
638643
Leaf: leafCert,
639644
})
640-
return err
645+
if err != nil {
646+
return err
647+
}
641648
}
642649
*peerVerifiedChains = chains
643650
return nil

security/advancedtls/advancedtls_test.go

Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
package advancedtls
2020

2121
import (
22+
"bytes"
2223
"context"
2324
"crypto/tls"
2425
"crypto/x509"
@@ -896,6 +897,76 @@ func (s) TestClientServerHandshake(t *testing.T) {
896897
t.Fatalf("c.ClientHandshake(_, %v, _) = %v, want %v.", lisAddr,
897898
clientAuthInfo, serverAuthInfo)
898899
}
900+
serverVerifiedChains := serverAuthInfo.(credentials.TLSInfo).State.VerifiedChains
901+
if test.serverMutualTLS && !test.serverExpectError {
902+
if len(serverVerifiedChains) == 0 {
903+
t.Fatalf("server verified chains is empty")
904+
}
905+
var clientCert *tls.Certificate
906+
if len(test.clientCert) > 0 {
907+
clientCert = &test.clientCert[0]
908+
} else if test.clientGetCert != nil {
909+
cert, _ := test.clientGetCert(&tls.CertificateRequestInfo{})
910+
clientCert = cert
911+
} else if test.clientIdentityProvider != nil {
912+
km, _ := test.clientIdentityProvider.KeyMaterial(nil)
913+
clientCert = &km.Certs[0]
914+
}
915+
if !bytes.Equal((*serverVerifiedChains[0][0]).Raw, clientCert.Certificate[0]) {
916+
t.Fatal("server verifiedChains leaf cert doesn't match client cert")
917+
}
918+
919+
var serverRoot *x509.CertPool
920+
if test.serverRoot != nil {
921+
serverRoot = test.serverRoot
922+
} else if test.serverGetRoot != nil {
923+
result, _ := test.serverGetRoot(&GetRootCAsParams{})
924+
serverRoot = result.TrustCerts
925+
} else if test.serverRootProvider != nil {
926+
km, _ := test.serverRootProvider.KeyMaterial(nil)
927+
serverRoot = km.Roots
928+
}
929+
serverVerifiedChainsCp := x509.NewCertPool()
930+
serverVerifiedChainsCp.AddCert(serverVerifiedChains[0][len(serverVerifiedChains[0])-1])
931+
if !serverVerifiedChainsCp.Equal(serverRoot) {
932+
t.Fatalf("server verified chain hierarchy doesn't match")
933+
}
934+
}
935+
clientVerifiedChains := clientAuthInfo.(credentials.TLSInfo).State.VerifiedChains
936+
if test.serverMutualTLS && !test.clientExpectHandshakeError {
937+
if len(clientVerifiedChains) == 0 {
938+
t.Fatalf("client verified chains is empty")
939+
}
940+
var serverCert *tls.Certificate
941+
if len(test.serverCert) > 0 {
942+
serverCert = &test.serverCert[0]
943+
} else if test.serverGetCert != nil {
944+
cert, _ := test.serverGetCert(&tls.ClientHelloInfo{})
945+
serverCert = cert[0]
946+
} else if test.serverIdentityProvider != nil {
947+
km, _ := test.serverIdentityProvider.KeyMaterial(nil)
948+
serverCert = &km.Certs[0]
949+
}
950+
if !bytes.Equal((*clientVerifiedChains[0][0]).Raw, serverCert.Certificate[0]) {
951+
t.Fatal("client verifiedChains leaf cert doesn't match server cert")
952+
}
953+
954+
var clientRoot *x509.CertPool
955+
if test.clientRoot != nil {
956+
clientRoot = test.clientRoot
957+
} else if test.clientGetRoot != nil {
958+
result, _ := test.clientGetRoot(&GetRootCAsParams{})
959+
clientRoot = result.TrustCerts
960+
} else if test.clientRootProvider != nil {
961+
km, _ := test.clientRootProvider.KeyMaterial(nil)
962+
clientRoot = km.Roots
963+
}
964+
clientVerifiedChainsCp := x509.NewCertPool()
965+
clientVerifiedChainsCp.AddCert(clientVerifiedChains[0][len(clientVerifiedChains[0])-1])
966+
if !clientVerifiedChainsCp.Equal(clientRoot) {
967+
t.Fatalf("client verified chain hierarchy doesn't match")
968+
}
969+
}
899970
})
900971
}
901972
}

0 commit comments

Comments
 (0)