diff --git a/internal/kubeclient/kubeclient_test.go b/internal/kubeclient/kubeclient_test.go index b5def24e9..6207f34be 100644 --- a/internal/kubeclient/kubeclient_test.go +++ b/internal/kubeclient/kubeclient_test.go @@ -5,6 +5,7 @@ package kubeclient import ( "context" + "crypto/x509" "fmt" "io" "net/http" @@ -949,14 +950,14 @@ func TestUnwrap(t *testing.T) { server, restConfig := fakekubeapi.Start(t, nil) - serverSubjects := server.Client().Transport.(*http.Transport).TLSClientConfig.RootCAs.Subjects() + serverCertPool := server.Client().Transport.(*http.Transport).TLSClientConfig.RootCAs t.Run("regular client", func(t *testing.T) { t.Parallel() // make sure to run in parallel to confirm that our client-go TLS cache busting works (i.e. assert no data races) regularClient := makeClient(t, restConfig, func(_ *rest.Config) {}) - testUnwrap(t, regularClient, serverSubjects, ptls.Secure) + testUnwrap(t, regularClient, serverCertPool, ptls.Secure) }) t.Run("exec client", func(t *testing.T) { @@ -971,7 +972,7 @@ func TestUnwrap(t *testing.T) { } }) - testUnwrap(t, execClient, serverSubjects, ptls.Secure) + testUnwrap(t, execClient, serverCertPool, ptls.Secure) }) t.Run("oidc client", func(t *testing.T) { @@ -987,7 +988,7 @@ func TestUnwrap(t *testing.T) { } }) - testUnwrap(t, oidcClient, serverSubjects, ptls.Secure) + testUnwrap(t, oidcClient, serverCertPool, ptls.Secure) }) t.Run("regular client with ptls.Default", func(t *testing.T) { @@ -995,7 +996,7 @@ func TestUnwrap(t *testing.T) { regularClient := makeClient(t, restConfig, func(_ *rest.Config) {}, WithTLSConfigFunc(ptls.Default)) - testUnwrap(t, regularClient, serverSubjects, ptls.Default) + testUnwrap(t, regularClient, serverCertPool, ptls.Default) }) t.Run("exec client with ptls.Default", func(t *testing.T) { @@ -1010,7 +1011,7 @@ func TestUnwrap(t *testing.T) { } }, WithTLSConfigFunc(ptls.Default)) - testUnwrap(t, execClient, serverSubjects, ptls.Default) + testUnwrap(t, execClient, serverCertPool, ptls.Default) }) t.Run("oidc client with ptls.Default", func(t *testing.T) { @@ -1026,11 +1027,11 @@ func TestUnwrap(t *testing.T) { } }, WithTLSConfigFunc(ptls.Default)) - testUnwrap(t, oidcClient, serverSubjects, ptls.Default) + testUnwrap(t, oidcClient, serverCertPool, ptls.Default) }) } -func testUnwrap(t *testing.T, client *Client, serverSubjects [][]byte, tlsConfigFuncForExpectedValues ptls.ConfigFunc) { +func testUnwrap(t *testing.T, client *Client, serverCertPool *x509.CertPool, tlsConfigFuncForExpectedValues ptls.ConfigFunc) { tests := []struct { name string rt http.RoundTripper @@ -1145,9 +1146,7 @@ func testUnwrap(t *testing.T, client *Client, serverSubjects [][]byte, tlsConfig require.Equal(t, ptlsConfig.MinVersion, tlsConfig.MinVersion) require.Equal(t, ptlsConfig.CipherSuites, tlsConfig.CipherSuites) require.Equal(t, ptlsConfig.NextProtos, tlsConfig.NextProtos) - - // x509.CertPool has some embedded functions that make it hard to compare so just look at the subjects - require.Equal(t, serverSubjects, tlsConfig.RootCAs.Subjects()) + require.True(t, serverCertPool.Equal(tlsConfig.RootCAs)) }) } }