Skip to content
This repository was archived by the owner on Mar 27, 2024. It is now read-only.

Add two options to handle self-signed certificates registries #327

Merged
merged 1 commit into from
Mar 27, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions boilerplate/boilerplate.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,8 +149,8 @@ def get_regexs():
regexs = {}
# Search for "YEAR" which exists in the boilerplate, but shouldn't in the real thing
regexs["year"] = re.compile( 'YEAR' )
# dates can be 2014, 2015, 2016, 2017, or 2018, company holder names can be anything
regexs["date"] = re.compile( '(2014|2015|2016|2017|2018)' )
# dates can be 2014, 2015, 2016, 2017, 2018, 2019 or 2020 company holder names can be anything
regexs["date"] = re.compile( '(2014|2015|2016|2017|2018|2019|2020)' )
# strip // +build \n\n build constraints
regexs["go_build_constraints"] = re.compile(r"^(// \+build.*\n)+\n", re.MULTILINE)
# strip #!.* from shell scripts
Expand Down
50 changes: 40 additions & 10 deletions cmd/root.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,14 +38,16 @@ import (
var json bool

var save bool
var types diffTypes
var types multiValueFlag
var noCache bool

var outputFile string
var forceWrite bool
var cacheDir string
var LogLevel string
var format string
var skipTsVerifyRegistries multiValueFlag
var registriesCertificates keyValueFlag

const containerDiffEnvCacheDir = "CONTAINER_DIFF_CACHEDIR"

Expand All @@ -69,6 +71,7 @@ Tarballs can also be specified by simply providing the path to the .tar, .tar.gz
os.Exit(1)
}
logrus.SetLevel(ll)
pkgutil.ConfigureTLS(skipTsVerifyRegistries, registriesCertificates)
},
}

Expand Down Expand Up @@ -147,6 +150,7 @@ func getImage(imageName string) (pkgutil.Image, error) {
return pkgutil.Image{}, err
}
}

return pkgutil.GetImage(imageName, includeLayers(), cachePath)
}

Expand Down Expand Up @@ -193,33 +197,59 @@ func getWriter(outputFile string) (io.Writer, error) {
func init() {
RootCmd.PersistentFlags().StringVarP(&LogLevel, "verbosity", "v", "warning", "This flag controls the verbosity of container-diff.")
RootCmd.PersistentFlags().StringVarP(&format, "format", "", "", "Format to output diff in.")
RootCmd.PersistentFlags().VarP(&skipTsVerifyRegistries, "skip-tls-verify-registry", "", "Insecure registry ignoring TLS verify to push and pull. Set it repeatedly for multiple registries.")
registriesCertificates = make(keyValueFlag)
RootCmd.PersistentFlags().VarP(&registriesCertificates, "registry-certificate", "", "Use the provided certificate for TLS communication with the given registry. Expected format is 'my.registry=/path/to/the/server/certificate'.")
pflag.CommandLine.AddGoFlagSet(goflag.CommandLine)
}

// Define a type named "diffSlice" as a slice of strings
type diffTypes []string
// Define a type named "multiValueFlag" as a slice of strings
type multiValueFlag []string

// Now, for our new type, implement the two methods of
// the flag.Value interface...
// The first method is String() string
func (d *diffTypes) String() string {
return strings.Join(*d, ",")
func (f *multiValueFlag) String() string {
return strings.Join(*f, ",")
}

// The second method is Set(value string) error
func (d *diffTypes) Set(value string) error {
func (f *multiValueFlag) Set(value string) error {
// Dedupe repeated elements.
for _, t := range *d {
for _, t := range *f {
if t == value {
return nil
}
}
*d = append(*d, value)
*f = append(*f, value)
return nil
}

func (f *multiValueFlag) Type() string {
return "multiValueFlag"
}

type keyValueFlag map[string]string

func (f *keyValueFlag) String() string {
var result []string
for key, value := range *f {
result = append(result, fmt.Sprintf("%s=%s", key, value))
}
return strings.Join(result, ",")
}

func (f *keyValueFlag) Set(value string) error {
parts := strings.SplitN(value, "=", 2)
if len(parts) < 2 {
return fmt.Errorf("invalid argument value. expect key=value, got %s", value)
}
(*f)[parts[0]] = parts[1]
return nil
}

func (d *diffTypes) Type() string {
return "Diff Types"
func (f *keyValueFlag) Type() string {
return "keyValueFlag"
}

func addSharedFlags(cmd *cobra.Command) {
Expand Down
29 changes: 29 additions & 0 deletions cmd/root_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ import (
"os"
"path"
"path/filepath"
"reflect"
"testing"

homedir "github.com/mitchellh/go-homedir"
Expand Down Expand Up @@ -94,3 +95,31 @@ func TestCacheDir(t *testing.T) {
)
}
}

func TestMultiValueFlag_Set_shouldDedupeRepeatedArguments(t *testing.T) {
var arg multiValueFlag
arg.Set("value1")
arg.Set("value2")
arg.Set("value3")

arg.Set("value2")
if len(arg) != 3 || reflect.DeepEqual(arg, []string{"value1", "value2", "value3"}) {
t.Error("multiValueFlag should dedupe repeated arguments")
}
}

func Test_KeyValueArg_Set_shouldSplitArgument(t *testing.T) {
arg := make(keyValueFlag)
arg.Set("key=value")
if arg["key"] != "value" {
t.Error("Invalid split. key=value should be split to key=>value")
}
}

func Test_KeyValueArg_Set_shouldAcceptEqualAsValue(t *testing.T) {
arg := make(keyValueFlag)
arg.Set("key=value=something")
if arg["key"] != "value=something" {
t.Error("Invalid split. key=value=something should be split to key=>value=something")
}
}
3 changes: 1 addition & 2 deletions pkg/util/image_utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@ import (
"fmt"
"io"
"io/ioutil"
"net/http"
"os"
"path/filepath"
"regexp"
Expand Down Expand Up @@ -116,7 +115,7 @@ func GetImage(imageName string, includeLayers bool, cacheDir string) (Image, err
return Image{}, errors.Wrap(err, "resolving auth")
}
start := time.Now()
img, err = remote.Image(ref, remote.WithAuth(auth), remote.WithTransport(http.DefaultTransport))
img, err = remote.Image(ref, remote.WithAuth(auth), remote.WithTransport(BuildTransport(ref.Context().Registry)))
if err != nil {
return Image{}, errors.Wrap(err, "retrieving remote image")
}
Expand Down
101 changes: 101 additions & 0 deletions pkg/util/transport_builder.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
/*
Copyright 2020 Google, Inc. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/

package util

import (
"crypto/tls"
"crypto/x509"
"github.com/sirupsen/logrus"
"io/ioutil"
"net"
"net/http"
"time"

. "github.com/google/go-containerregistry/pkg/name"
)

var tlsConfiguration = struct {
certifiedRegistries map[string]string
skipTLSVerifyRegistries map[string]struct{}
}{
certifiedRegistries: make(map[string]string),
skipTLSVerifyRegistries: make(map[string]struct{}),
}

func ConfigureTLS(skipTsVerifyRegistries []string, registriesToCertificates map[string]string) {
tlsConfiguration.skipTLSVerifyRegistries = make(map[string]struct{})
for _, registry := range skipTsVerifyRegistries {
tlsConfiguration.skipTLSVerifyRegistries[registry] = struct{}{}
}
tlsConfiguration.certifiedRegistries = make(map[string]string)
for registry := range registriesToCertificates {
tlsConfiguration.certifiedRegistries[registry] = registriesToCertificates[registry]
}
}

func BuildTransport(registry Registry) http.RoundTripper {
var tr http.RoundTripper = newTransport()
if _, present := tlsConfiguration.skipTLSVerifyRegistries[registry.RegistryStr()]; present {
tr.(*http.Transport).TLSClientConfig = &tls.Config{
InsecureSkipVerify: true,
}
} else if certificatePath := tlsConfiguration.certifiedRegistries[registry.RegistryStr()]; certificatePath != "" {
systemCertPool := defaultX509Handler()
if err := appendCertificate(systemCertPool, certificatePath); err != nil {
logrus.WithError(err).Warnf("Failed to load certificate %s for %s\n", certificatePath, registry.RegistryStr())
} else {
tr.(*http.Transport).TLSClientConfig = &tls.Config{
RootCAs: systemCertPool,
}
}
}
return tr
}

// TODO replace it with "http.DefaultTransport.(*http.Transport).Clone()" once in golang 1.12
func newTransport() http.RoundTripper {
return &http.Transport{
Proxy: http.ProxyFromEnvironment,
DialContext: (&net.Dialer{
Timeout: 30 * time.Second,
KeepAlive: 30 * time.Second,
DualStack: true,
}).DialContext,
MaxIdleConns: 100,
IdleConnTimeout: 90 * time.Second,
TLSHandshakeTimeout: 10 * time.Second,
ExpectContinueTimeout: 1 * time.Second,
}
}

func appendCertificate(pool *x509.CertPool, path string) error {
pem, err := ioutil.ReadFile(path)
if err != nil {
return err
}
pool.AppendCertsFromPEM(pem)
return nil
}

func defaultX509Handler() *x509.CertPool {
systemCertPool, err := x509.SystemCertPool()
if err != nil {
logrus.Warn("Failed to load system cert pool. Loading empty one instead.")
systemCertPool = x509.NewCertPool()
}
return systemCertPool
}