|
19 | 19 | package advancedtls
|
20 | 20 |
|
21 | 21 | import (
|
| 22 | + "bytes" |
22 | 23 | "context"
|
23 | 24 | "crypto/tls"
|
24 | 25 | "crypto/x509"
|
@@ -896,6 +897,76 @@ func (s) TestClientServerHandshake(t *testing.T) {
|
896 | 897 | t.Fatalf("c.ClientHandshake(_, %v, _) = %v, want %v.", lisAddr,
|
897 | 898 | clientAuthInfo, serverAuthInfo)
|
898 | 899 | }
|
| 900 | + serverVerifiedChains := serverAuthInfo.(credentials.TLSInfo).State.VerifiedChains |
| 901 | + if test.serverMutualTLS && !test.serverExpectError { |
| 902 | + if len(serverVerifiedChains) == 0 { |
| 903 | + t.Fatalf("server verified chains is empty") |
| 904 | + } |
| 905 | + var clientCert *tls.Certificate |
| 906 | + if len(test.clientCert) > 0 { |
| 907 | + clientCert = &test.clientCert[0] |
| 908 | + } else if test.clientGetCert != nil { |
| 909 | + cert, _ := test.clientGetCert(&tls.CertificateRequestInfo{}) |
| 910 | + clientCert = cert |
| 911 | + } else if test.clientIdentityProvider != nil { |
| 912 | + km, _ := test.clientIdentityProvider.KeyMaterial(nil) |
| 913 | + clientCert = &km.Certs[0] |
| 914 | + } |
| 915 | + if !bytes.Equal((*serverVerifiedChains[0][0]).Raw, clientCert.Certificate[0]) { |
| 916 | + t.Fatal("server verifiedChains leaf cert doesn't match client cert") |
| 917 | + } |
| 918 | + |
| 919 | + var serverRoot *x509.CertPool |
| 920 | + if test.serverRoot != nil { |
| 921 | + serverRoot = test.serverRoot |
| 922 | + } else if test.serverGetRoot != nil { |
| 923 | + result, _ := test.serverGetRoot(&GetRootCAsParams{}) |
| 924 | + serverRoot = result.TrustCerts |
| 925 | + } else if test.serverRootProvider != nil { |
| 926 | + km, _ := test.serverRootProvider.KeyMaterial(nil) |
| 927 | + serverRoot = km.Roots |
| 928 | + } |
| 929 | + serverVerifiedChainsCp := x509.NewCertPool() |
| 930 | + serverVerifiedChainsCp.AddCert(serverVerifiedChains[0][len(serverVerifiedChains[0])-1]) |
| 931 | + if !serverVerifiedChainsCp.Equal(serverRoot) { |
| 932 | + t.Fatalf("server verified chain hierarchy doesn't match") |
| 933 | + } |
| 934 | + } |
| 935 | + clientVerifiedChains := clientAuthInfo.(credentials.TLSInfo).State.VerifiedChains |
| 936 | + if test.serverMutualTLS && !test.clientExpectHandshakeError { |
| 937 | + if len(clientVerifiedChains) == 0 { |
| 938 | + t.Fatalf("client verified chains is empty") |
| 939 | + } |
| 940 | + var serverCert *tls.Certificate |
| 941 | + if len(test.serverCert) > 0 { |
| 942 | + serverCert = &test.serverCert[0] |
| 943 | + } else if test.serverGetCert != nil { |
| 944 | + cert, _ := test.serverGetCert(&tls.ClientHelloInfo{}) |
| 945 | + serverCert = cert[0] |
| 946 | + } else if test.serverIdentityProvider != nil { |
| 947 | + km, _ := test.serverIdentityProvider.KeyMaterial(nil) |
| 948 | + serverCert = &km.Certs[0] |
| 949 | + } |
| 950 | + if !bytes.Equal((*clientVerifiedChains[0][0]).Raw, serverCert.Certificate[0]) { |
| 951 | + t.Fatal("client verifiedChains leaf cert doesn't match server cert") |
| 952 | + } |
| 953 | + |
| 954 | + var clientRoot *x509.CertPool |
| 955 | + if test.clientRoot != nil { |
| 956 | + clientRoot = test.clientRoot |
| 957 | + } else if test.clientGetRoot != nil { |
| 958 | + result, _ := test.clientGetRoot(&GetRootCAsParams{}) |
| 959 | + clientRoot = result.TrustCerts |
| 960 | + } else if test.clientRootProvider != nil { |
| 961 | + km, _ := test.clientRootProvider.KeyMaterial(nil) |
| 962 | + clientRoot = km.Roots |
| 963 | + } |
| 964 | + clientVerifiedChainsCp := x509.NewCertPool() |
| 965 | + clientVerifiedChainsCp.AddCert(clientVerifiedChains[0][len(clientVerifiedChains[0])-1]) |
| 966 | + if !clientVerifiedChainsCp.Equal(clientRoot) { |
| 967 | + t.Fatalf("client verified chain hierarchy doesn't match") |
| 968 | + } |
| 969 | + } |
899 | 970 | })
|
900 | 971 | }
|
901 | 972 | }
|
|
0 commit comments