diff --git a/Makefile b/Makefile index d190c1cf..d59caf24 100644 --- a/Makefile +++ b/Makefile @@ -45,10 +45,11 @@ cmd_%: go build $(BUILDFLAGS) -o $(OUTPUT) $(SOURCE) test: ./bin/gocovmerge + rm -f .cover.* go test -coverprofile=.cover.pkg ./... cd lib && go test -coverprofile=../.cover.lib ./... ./bin/gocovmerge .cover.* > .cover - rm .cover.* + rm -f .cover.* go tool cover -html=.cover -o .cover.html ./bin/gocovmerge: diff --git a/conf/weirproxy.yaml b/conf/weirproxy.yaml index a7ea501a..36398a77 100644 --- a/conf/weirproxy.yaml +++ b/conf/weirproxy.yaml @@ -20,6 +20,36 @@ log: max-backups: 1 security: rsa-key-size: 4096 + # tls object is either of type server, client, or peer + # xxxx: + # ca: ca.pem + # cert: c.pem + # key: k.pem + # auto-certs: true + # skip-ca: trure + # client object: + # 1. requires: ca or skip-ca(skip verify server certs) + # 2. optionally: cert/key will be used if server asks + # 3. useless/forbid: auto-certs + # server object: + # 1. requires: cert/key or auto-certs(generate a temporary cert, mostly for testing) + # 2. optionally: ca will enable server-side client verification. + # 3. useless/forbid: skip-ca + # peer object: + # 1. requires: cert/key/ca or auto-certs + # 2. useless/forbid: skip-ca + cluster-tls: # client object + # access to other components like TiDB or PD, will use this + skip-ca: true + sql-tls: # client object + # access to TiDB sql port, it has a standalone TLS configuration + skip-ca: true + server-tls: # server object + # proxy SQL or HTTP port will use this + auto-certs: true + peer-tls: # peer object + # internal communication between proxies + auto-certs: true advance: # ignore-wrong-namespace: true # peer-port: "3081" diff --git a/docker/Dockerfile b/docker/Dockerfile index c3a718fe..1de29978 100644 --- a/docker/Dockerfile +++ b/docker/Dockerfile @@ -13,6 +13,6 @@ RUN sh ./proxy/apk-fastest-mirror.sh RUN apk add --no-cache --progress git make go ARG BUILDFLAGS ARG GOPROXY -RUN export BUILDFLAGS=${BUILDFLAGS} && export GOPROXY=${GOPROXY} && cd proxy && ls -al && cat Makefile && make cmd && cp bin/* /bin/ && cp -a conf /etc/proxy && cd .. && rm -rf proxy +RUN export BUILDFLAGS=${BUILDFLAGS} && export GOPROXY=${GOPROXY} && cd proxy && make cmd && cp bin/* /bin/ && cp -a conf /etc/proxy && cd .. && rm -rf proxy RUN rm -rf $(go env GOMODCACHE GOCACHE) && apk del git make go ENTRYPOINT ["/bin/weirproxy", "-conf", "/etc/proxy/weirproxy.yaml"] diff --git a/lib/config/namespace.go b/lib/config/namespace.go index 521fa7ae..5fdeb778 100644 --- a/lib/config/namespace.go +++ b/lib/config/namespace.go @@ -22,11 +22,11 @@ type Namespace struct { } type FrontendNamespace struct { - Security TLSCert `yaml:"security" json:"security" toml:"security"` + Security TLSConfig `yaml:"security" json:"security" toml:"security"` } type BackendNamespace struct { - Instances []string `yaml:"instances" json:"instances" toml:"instances"` - SelectorType string `yaml:"selector-type" json:"selector-type" toml:"selector-type"` - Security TLSCert `yaml:"security" json:"security" toml:"security"` + Instances []string `yaml:"instances" json:"instances" toml:"instances"` + SelectorType string `yaml:"selector-type" json:"selector-type" toml:"selector-type"` + Security TLSConfig `yaml:"security" json:"security" toml:"security"` } diff --git a/lib/config/namespace_test.go b/lib/config/namespace_test.go index 003e27bc..3f2aea96 100644 --- a/lib/config/namespace_test.go +++ b/lib/config/namespace_test.go @@ -24,19 +24,21 @@ import ( var testNamespaceConfig = Namespace{ Namespace: "test_ns", Frontend: FrontendNamespace{ - Security: TLSCert{ - CA: "t", - Cert: "t", - Key: "t", + Security: TLSConfig{ + CA: "t", + Cert: "t", + Key: "t", + AutoCerts: true, }, }, Backend: BackendNamespace{ Instances: []string{"127.0.0.1:4000", "127.0.0.1:4001"}, SelectorType: "random", - Security: TLSCert{ - CA: "t", - Cert: "t", - Key: "t", + Security: TLSConfig{ + CA: "t", + Cert: "t", + Key: "t", + SkipCA: true, }, }, } diff --git a/lib/config/proxy.go b/lib/config/proxy.go index d2df9960..e8429b39 100644 --- a/lib/config/proxy.go +++ b/lib/config/proxy.go @@ -75,24 +75,28 @@ type LogFile struct { MaxBackups int `yaml:"max-backups,omitempty" toml:"max-backups,omitempty" json:"max-backups,omitempty"` } -type TLSCert struct { - CA string `yaml:"ca,omitempty" toml:"ca,omitempty" json:"ca,omitempty"` - Cert string `yaml:"cert,omitempty" toml:"cert,omitempty" json:"cert,omitempty"` - Key string `yaml:"key,omitempty" toml:"key,omitempty" json:"key,omitempty"` +type TLSConfig struct { + Cert string `yaml:"cert,omitempty" toml:"cert,omitempty" json:"cert,omitempty"` + Key string `yaml:"key,omitempty" toml:"key,omitempty" json:"key,omitempty"` + AutoCerts bool `yaml:"auto-certs,omitempty" toml:"auto-certs,omitempty" json:"auto-certs,omitempty"` + CA string `yaml:"ca,omitempty" toml:"ca,omitempty" json:"ca,omitempty"` + SkipCA bool `yaml:"skip-ca,omitempty" toml:"skip-ca,omitempty" json:"skip-ca,omitempty"` } -func (c TLSCert) HasCert() bool { +func (c TLSConfig) HasCert() bool { return !(c.Cert == "" && c.Key == "") } -func (c TLSCert) HasCA() bool { +func (c TLSConfig) HasCA() bool { return c.CA != "" } type Security struct { - RSAKeySize int `yaml:"rsa-key-size,omitempty" toml:"rsa-key-size,omitempty" json:"rsa-key-size,omitempty"` - Server TLSCert `yaml:"server,omitempty" toml:"server,omitempty" json:"server,omitempty"` - Cluster TLSCert `yaml:"cluster,omitempty" toml:"cluster,omitempty" json:"cluster,omitempty"` + RSAKeySize int `yaml:"rsa-key-size,omitempty" toml:"rsa-key-size,omitempty" json:"rsa-key-size,omitempty"` + ServerTLS TLSConfig `yaml:"server-tls,omitempty" toml:"server-tls,omitempty" json:"server-tls,omitempty"` + PeerTLS TLSConfig `yaml:"peer-tls,omitempty" toml:"peer-tls,omitempty" json:"peer-tls,omitempty"` + ClusterTLS TLSConfig `yaml:"cluster-tls,omitempty" toml:"cluster-tls,omitempty" json:"cluster-tls,omitempty"` + SQLTLS TLSConfig `yaml:"sql-tls,omitempty" toml:"sql-tls,omitempty" json:"sql-tls,omitempty"` } func NewConfig(data []byte) (*Config, error) { diff --git a/lib/config/proxy_test.go b/lib/config/proxy_test.go index d205e6d5..66cf0182 100644 --- a/lib/config/proxy_test.go +++ b/lib/config/proxy_test.go @@ -57,15 +57,29 @@ var testProxyConfig = Config{ }, Security: Security{ RSAKeySize: 64, - Server: TLSCert{ - CA: "a", - Cert: "b", - Key: "c", + ServerTLS: TLSConfig{ + CA: "a", + Cert: "b", + Key: "c", + AutoCerts: true, }, - Cluster: TLSCert{ - CA: "a", - Cert: "b", - Key: "c", + PeerTLS: TLSConfig{ + CA: "a", + Cert: "b", + Key: "c", + AutoCerts: true, + }, + ClusterTLS: TLSConfig{ + CA: "a", + SkipCA: true, + Cert: "b", + Key: "c", + }, + SQLTLS: TLSConfig{ + CA: "a", + SkipCA: true, + Cert: "b", + Key: "c", }, }, } diff --git a/lib/go.mod b/lib/go.mod index 29cb4ac5..090c9036 100644 --- a/lib/go.mod +++ b/lib/go.mod @@ -3,10 +3,10 @@ module github.com/pingcap/TiProxy/lib go 1.19 require ( - github.com/pingcap/errors v0.11.4 github.com/pingcap/log v1.1.0 github.com/spf13/cobra v1.5.0 github.com/stretchr/testify v1.8.0 + go.etcd.io/etcd/client/pkg/v3 v3.5.4 go.uber.org/atomic v1.9.0 go.uber.org/zap v1.23.0 gopkg.in/yaml.v3 v3.0.1 @@ -17,9 +17,12 @@ require ( github.com/davecgh/go-spew v1.1.1 // indirect github.com/inconshreveable/mousetrap v1.0.0 // indirect github.com/kr/text v0.2.0 // indirect + github.com/pingcap/errors v0.11.4 // indirect + github.com/pkg/errors v0.9.1 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect github.com/spf13/pflag v1.0.5 // indirect go.uber.org/multierr v1.7.0 // indirect + golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1 // indirect gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c // indirect gopkg.in/natefinch/lumberjack.v2 v2.0.0 // indirect ) diff --git a/lib/go.sum b/lib/go.sum index e9b6df37..e5d4c5a6 100644 --- a/lib/go.sum +++ b/lib/go.sum @@ -2,11 +2,13 @@ github.com/BurntSushi/toml v0.3.1 h1:WXkYYl6Yr3qBf1K79EBnL4mak0OimBfB0XUf9Vl28OQ github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU= github.com/benbjohnson/clock v1.1.0 h1:Q92kusRqC1XV2MjkWETPvjJVqKetz1OzxZB7mHJLju8= github.com/benbjohnson/clock v1.1.0/go.mod h1:J11/hYXuz8f4ySSvYwY0FKfm+ezbsZBKZxNJlLklBHA= +github.com/coreos/go-systemd/v22 v22.3.2/go.mod h1:Y58oyj3AT4RCenI/lSvhwexgC+NSVTIJ3seZv2GcEnc= github.com/cpuguy83/go-md2man/v2 v2.0.2/go.mod h1:tgQtvFlXSQOSOSIRvRPT7W67SCa46tRHOmNcaadrF8o= github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/godbus/dbus/v5 v5.0.4/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA= github.com/inconshreveable/mousetrap v1.0.0 h1:Z8tu5sraLXCXIcARxBp/8cbvlwVa7Z1NHg9XEKhtSvM= github.com/inconshreveable/mousetrap v1.0.0/go.mod h1:PxqpIevigyE2G7u3NXJIT2ANytuPF1OarO4DADm73n8= github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo= @@ -21,8 +23,9 @@ github.com/pingcap/errors v0.11.4 h1:lFuQV/oaUMGcD2tqt+01ROSmJs75VG1ToEOkZIZ4nE4 github.com/pingcap/errors v0.11.4/go.mod h1:Oi8TUi2kEtXXLMJk9l1cGmz20kV3TaQ0usTwv5KuLY8= github.com/pingcap/log v1.1.0 h1:ELiPxACz7vdo1qAvvaWJg1NrYFoY6gqAh/+Uo6aXdD8= github.com/pingcap/log v1.1.0/go.mod h1:DWQW5jICDR7UJh4HtxXSM20Churx4CQL0fwL/SoOSA4= -github.com/pkg/errors v0.8.1 h1:iURUrRGxPUNPdy5/HRSm+Yj6okJ6UtLINN0Q9M4+h3I= github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= +github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= +github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM= @@ -38,6 +41,8 @@ github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/ github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.8.0 h1:pSgiaMZlXftHpm5L7V1+rVB+AZJydKsMxsQBIJw4PKk= github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= +go.etcd.io/etcd/client/pkg/v3 v3.5.4 h1:lrneYvz923dvC14R54XcA7FXoZ3mlGZAgmwhfm7HqOg= +go.etcd.io/etcd/client/pkg/v3 v3.5.4/go.mod h1:IJHfcCEKxYu1Os13ZdwCwIUTUVGYTSAM3YSwc9/Ac1g= go.uber.org/atomic v1.7.0/go.mod h1:fEN4uk6kAWBTFdckzkM89CLk9XfWZrxpCo0nPH17wJc= go.uber.org/atomic v1.9.0 h1:ECmE8Bn/WFTYwEW/bpKD3M8VtR/zQVbavAoalC1PYyE= go.uber.org/atomic v1.9.0/go.mod h1:fEN4uk6kAWBTFdckzkM89CLk9XfWZrxpCo0nPH17wJc= @@ -46,6 +51,7 @@ go.uber.org/goleak v1.1.11 h1:wy28qYRKZgnJTxGxvye5/wgWr1EKjmUDGYox5mGlRlI= go.uber.org/multierr v1.6.0/go.mod h1:cdWPpRnG4AhwMwsgIHip0KRBQjJy5kYEpYjJxpXp9iU= go.uber.org/multierr v1.7.0 h1:zaiO/rmgFjbmCXdSYJWQcdvOCsthmdaHfr3Gm2Kx4Ec= go.uber.org/multierr v1.7.0/go.mod h1:7EAYxJLBy9rStEaz58O2t4Uvip6FSURkq8/ppBp95ak= +go.uber.org/zap v1.17.0/go.mod h1:MXVU+bhUf/A7Xi2HNOnopQOrmycQ5Ih87HtOu4q5SSo= go.uber.org/zap v1.19.0/go.mod h1:xg/QME4nWcxGxrpdeYfq7UvYrLh66cuVKdrbD1XF/NI= go.uber.org/zap v1.23.0 h1:OjGQ5KQDEUawVHxNwQgPpiypGHOxo2mNZsOqTak4fFY= go.uber.org/zap v1.23.0/go.mod h1:D+nX8jyLsMHMYrln8A0rJjFt/T/9/bGgIhAqxv5URuY= @@ -55,6 +61,9 @@ golang.org/x/net v0.0.0-20190311183353-d8887717615a/go.mod h1:t9HGtf8HONx5eT2rtn golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sys v0.0.0-20210403161142-5e06dd20ab57/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1 h1:SrN+KX8Art/Sf4HNj6Zcz06G7VEz+7w9tdXTPOZ7+l4= +golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/tools v0.0.0-20190311212946-11955173bddd/go.mod h1:LCzVGOaR6xXOjkQ3onu1FJEFr0SW1gC7cKk1uF8kGRs= golang.org/x/tools v0.0.0-20191108193012-7d206e10da11/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= diff --git a/lib/util/cmd/encoder.go b/lib/util/cmd/encoder.go index 7212d2b1..8258fa36 100644 --- a/lib/util/cmd/encoder.go +++ b/lib/util/cmd/encoder.go @@ -65,16 +65,12 @@ func (c *tidbEncoder) endQuoteFiled() { func (c *tidbEncoder) encodeError(f zapcore.Field) { err := f.Interface.(error) basic := err.Error() - c.beginQuoteFiled() c.AddString(f.Key, basic) - c.endQuoteFiled() if e, isFormatter := err.(fmt.Formatter); isFormatter { verbose := fmt.Sprintf("%+v", e) if verbose != basic { // This is a rich error type, like those produced by github.com/pkg/errors. - c.beginQuoteFiled() c.AddString(f.Key+"Verbose", verbose) - c.endQuoteFiled() } } } @@ -125,6 +121,9 @@ func (e *tidbEncoder) EncodeEntry(ent zapcore.Entry, fields []zapcore.Field) (*b c.line.AppendByte(' ') } + // append the old fields + c.line.WriteString(e.line.String()) + for _, f := range fields { if f.Type == zapcore.ErrorType { // handle ErrorType in pingcap/log to fix "[key=?,keyVerbose=?]" problem. @@ -132,9 +131,7 @@ func (e *tidbEncoder) EncodeEntry(ent zapcore.Entry, fields []zapcore.Field) (*b c.encodeError(f) continue } - c.beginQuoteFiled() f.AddTo(c) - c.endQuoteFiled() } c.closeOpenNamespaces() @@ -217,103 +214,151 @@ func (s *tidbEncoder) addKey(key string) { s.line.AppendByte('=') } func (s *tidbEncoder) AddArray(key string, arr zapcore.ArrayMarshaler) error { + s.beginQuoteFiled() s.addKey(key) - return s.AppendArray(arr) + err := s.AppendArray(arr) + s.endQuoteFiled() + return err } func (s *tidbEncoder) AddObject(key string, obj zapcore.ObjectMarshaler) error { + s.beginQuoteFiled() s.addKey(key) - return s.AppendObject(obj) + err := s.AppendObject(obj) + s.endQuoteFiled() + return err } func (s *tidbEncoder) AddBinary(key string, val []byte) { s.AddString(key, base64.StdEncoding.EncodeToString(val)) } func (s *tidbEncoder) AddByteString(key string, val []byte) { + s.beginQuoteFiled() s.addKey(key) s.AppendByteString(val) + s.endQuoteFiled() } func (s *tidbEncoder) AddBool(key string, val bool) { + s.beginQuoteFiled() s.addKey(key) s.AppendBool(val) + s.endQuoteFiled() } func (s *tidbEncoder) AddComplex128(key string, val complex128) { + s.beginQuoteFiled() s.addKey(key) s.AppendComplex128(val) + s.endQuoteFiled() } func (s *tidbEncoder) AddComplex64(key string, val complex64) { + s.beginQuoteFiled() s.addKey(key) s.AppendComplex64(val) + s.endQuoteFiled() } func (s *tidbEncoder) AddDuration(key string, val time.Duration) { + s.beginQuoteFiled() s.addKey(key) s.AppendDuration(val) + s.endQuoteFiled() } func (s *tidbEncoder) AddFloat64(key string, val float64) { + s.beginQuoteFiled() s.addKey(key) s.AppendFloat64(val) + s.endQuoteFiled() } func (s *tidbEncoder) AddFloat32(key string, val float32) { + s.beginQuoteFiled() s.addKey(key) s.AppendFloat32(val) + s.endQuoteFiled() } func (s *tidbEncoder) AddInt(key string, val int) { + s.beginQuoteFiled() s.addKey(key) s.AppendInt(val) + s.endQuoteFiled() } func (s *tidbEncoder) AddInt8(key string, val int8) { + s.beginQuoteFiled() s.addKey(key) s.AppendInt8(val) + s.endQuoteFiled() } func (s *tidbEncoder) AddInt16(key string, val int16) { + s.beginQuoteFiled() s.addKey(key) s.AppendInt16(val) + s.endQuoteFiled() } func (s *tidbEncoder) AddInt32(key string, val int32) { + s.beginQuoteFiled() s.addKey(key) s.AppendInt32(val) + s.endQuoteFiled() } func (s *tidbEncoder) AddInt64(key string, val int64) { + s.beginQuoteFiled() s.addKey(key) s.AppendInt64(val) + s.endQuoteFiled() } func (s *tidbEncoder) AddString(key string, val string) { + s.beginQuoteFiled() s.addKey(key) s.AppendString(val) + s.endQuoteFiled() } func (s *tidbEncoder) AddTime(key string, val time.Time) { + s.beginQuoteFiled() s.addKey(key) s.AppendTime(val) + s.endQuoteFiled() } func (s *tidbEncoder) AddUint(key string, val uint) { + s.beginQuoteFiled() s.addKey(key) s.AppendUint(val) + s.endQuoteFiled() } func (s *tidbEncoder) AddUint8(key string, val uint8) { + s.beginQuoteFiled() s.addKey(key) s.AppendUint8(val) + s.endQuoteFiled() } func (s *tidbEncoder) AddUint16(key string, val uint16) { + s.beginQuoteFiled() s.addKey(key) s.AppendUint16(val) + s.endQuoteFiled() } func (s *tidbEncoder) AddUint32(key string, val uint32) { + s.beginQuoteFiled() s.addKey(key) s.AppendUint32(val) + s.endQuoteFiled() } func (s *tidbEncoder) AddUint64(key string, val uint64) { + s.beginQuoteFiled() s.addKey(key) s.AppendUint64(val) + s.endQuoteFiled() } func (s *tidbEncoder) AddUintptr(key string, val uintptr) { + s.beginQuoteFiled() s.addKey(key) s.AppendUintptr(val) + s.endQuoteFiled() } func (s *tidbEncoder) AddReflected(key string, obj interface{}) error { + s.beginQuoteFiled() s.addKey(key) enc := json.NewEncoder(s.line) if err := enc.Encode(obj); err != nil { return err } s.line.TrimNewline() + s.endQuoteFiled() return nil } func (s *tidbEncoder) OpenNamespace(key string) { diff --git a/lib/util/security/tls.go b/lib/util/security/tls.go index 437c0c62..60ad1bc0 100644 --- a/lib/util/security/tls.go +++ b/lib/util/security/tls.go @@ -22,186 +22,85 @@ import ( "crypto/x509" "crypto/x509/pkix" "encoding/pem" + "io/ioutil" "math/big" "net" "os" "path/filepath" "time" - "github.com/pingcap/errors" + "github.com/pingcap/TiProxy/lib/config" + "github.com/pingcap/TiProxy/lib/util/errors" + "go.etcd.io/etcd/client/pkg/v3/transport" "go.uber.org/zap" ) -// CreateServerTLSConfig creates a tlsConfig that is used to connect to the client. -func CreateServerTLSConfig(logger *zap.Logger, ca, key, cert string, rsaKeySize int, workdir string) (tlsConfig *tls.Config, err error) { - if len(cert) == 0 || len(key) == 0 { - cert = filepath.Join(workdir, "cert.pem") - key = filepath.Join(workdir, "key.pem") - if err := createTLSCertificates(logger, cert, key, rsaKeySize); err != nil { - return nil, err - } - } +func createTLSConfigificates(logger *zap.Logger, certpath, keypath, capath string, rsaKeySize int) error { + logger = logger.With(zap.String("cert", certpath), zap.String("key", keypath), zap.String("ca", capath), zap.Int("rsaKeySize", rsaKeySize)) - var tlsCert tls.Certificate - tlsCert, err = tls.LoadX509KeyPair(cert, key) - if err != nil { - logger.Warn("load x509 failed", zap.Error(err)) - err = errors.Trace(err) - return + _, e1 := os.Stat(certpath) + _, e2 := os.Stat(keypath) + if errors.Is(e1, os.ErrExist) || errors.Is(e2, os.ErrExist) { + logger.Warn("either cert or key exists") + return nil } - // Try loading CA cert. - clientAuthPolicy := tls.NoClientCert - var certPool *x509.CertPool - if len(ca) > 0 { - var caCert []byte - caCert, err = os.ReadFile(ca) - if err != nil { - logger.Warn("read file failed", zap.Error(err)) - err = errors.Trace(err) - return - } - certPool = x509.NewCertPool() - if certPool.AppendCertsFromPEM(caCert) { - clientAuthPolicy = tls.VerifyClientCertIfGiven + if capath != "" { + _, e3 := os.Stat(capath) + if errors.Is(e3, os.ErrExist) { + logger.Warn("ca exists") + return nil } } - tlsConfig = &tls.Config{ - Certificates: []tls.Certificate{tlsCert}, - ClientCAs: certPool, - ClientAuth: clientAuthPolicy, - } - return -} - -func createTLSCertificates(logger *zap.Logger, certpath string, keypath string, rsaKeySize int) error { - privkey, err := rsa.GenerateKey(rand.Reader, rsaKeySize) - if err != nil { - return err - } - - certValidity := 90 * 24 * time.Hour // 90 days - notBefore := time.Now() - notAfter := notBefore.Add(certValidity) - hostname, err := os.Hostname() - if err != nil { + if err := os.MkdirAll(filepath.Dir(keypath), 0755); err != nil { return err } - - template := x509.Certificate{ - Subject: pkix.Name{ - CommonName: "TiDB_Server_Auto_Generated_Server_Certificate", - }, - SerialNumber: big.NewInt(1), - NotBefore: notBefore, - NotAfter: notAfter, - DNSNames: []string{hostname}, - } - - // DER: Distinguished Encoding Rules, this is the ASN.1 encoding rule of the certificate. - derBytes, err := x509.CreateCertificate(rand.Reader, &template, &template, &privkey.PublicKey, privkey) - if err != nil { - return err - } - - certOut, err := os.Create(certpath) - if err != nil { + if err := os.MkdirAll(filepath.Dir(certpath), 0755); err != nil { return err } - if err := pem.Encode(certOut, &pem.Block{Type: "CERTIFICATE", Bytes: derBytes}); err != nil { - return err - } - if err := certOut.Close(); err != nil { - return err + if capath != "" { + if err := os.MkdirAll(filepath.Dir(capath), 0755); err != nil { + return err + } } - keyOut, err := os.OpenFile(keypath, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0600) + certPEM, keyPEM, caPEM, err := CreateTempTLS() if err != nil { return err } - privBytes, err := x509.MarshalPKCS8PrivateKey(privkey) - if err != nil { + if err := ioutil.WriteFile(certpath, certPEM.Bytes(), 0600); err != nil { return err } - - if err := pem.Encode(keyOut, &pem.Block{Type: "PRIVATE KEY", Bytes: privBytes}); err != nil { + if err := ioutil.WriteFile(keypath, keyPEM.Bytes(), 0600); err != nil { return err } - - if err := keyOut.Close(); err != nil { - return err + if capath != "" { + if err := ioutil.WriteFile(capath, caPEM.Bytes(), 0600); err != nil { + return err + } } - logger.Info("TLS Certificates created", zap.String("cert", certpath), zap.String("key", keypath), - zap.Duration("validity", certValidity), zap.Int("rsaKeySize", rsaKeySize)) + logger.Info("TLS Certificates created") return nil } -// CreateClusterTLSConfig generates tls's config based on security section of the config. -// It's used to connect to PD. -func CreateClusterTLSConfig(sslCA, sslKey, sslCert string) (tlsConfig *tls.Config, err error) { - if len(sslCA) != 0 { - certPool := x509.NewCertPool() - // Create a certificate pool from the certificate authority - var ca []byte - ca, err = os.ReadFile(sslCA) - if err != nil { - err = errors.Errorf("could not read ca certificate: %s", err) - return - } - // Append the certificates from the CA - if !certPool.AppendCertsFromPEM(ca) { - err = errors.New("failed to append ca certs") - return - } - tlsConfig = &tls.Config{ - RootCAs: certPool, - ClientCAs: certPool, +func AutoTLS(logger *zap.Logger, scfg *config.TLSConfig, autoca bool, workdir, mod string, keySize int) error { + if !scfg.HasCert() && scfg.AutoCerts { + scfg.Cert = filepath.Join(workdir, mod, "cert.pem") + scfg.Key = filepath.Join(workdir, mod, "key.pem") + if autoca { + scfg.CA = filepath.Join(workdir, mod, "ca.pem") } - - if len(sslCert) != 0 && len(sslKey) != 0 { - getCert := func() (*tls.Certificate, error) { - // Load the client certificates from disk - cert, err := tls.LoadX509KeyPair(sslCert, sslKey) - if err != nil { - return nil, errors.Errorf("could not load client key pair: %s", err) - } - return &cert, nil - } - // pre-test cert's loading. - if _, err = getCert(); err != nil { - return - } - tlsConfig.GetClientCertificate = func(info *tls.CertificateRequestInfo) (certificate *tls.Certificate, err error) { - return getCert() - } - tlsConfig.GetCertificate = func(info *tls.ClientHelloInfo) (certificate *tls.Certificate, err error) { - return getCert() - } + if err := createTLSConfigificates(logger, scfg.Cert, scfg.Key, scfg.CA, keySize); err != nil { + return errors.WithStack(err) } } - return -} - -// CreateClientTLSConfig creates a tlsConfig that is used to connect to the backend server. -func CreateClientTLSConfig(sslCA, sslKey, sslCert string) (tlsConfig *tls.Config, err error) { - tlsConfig, err = CreateClusterTLSConfig(sslCA, sslKey, sslCert) - if err != nil { - return nil, err - } - if tlsConfig != nil { - return tlsConfig, nil - } - tlsConfig = &tls.Config{ - InsecureSkipVerify: true, - } - return + return nil } -// CreateTLSConfigForTest is from https://gist.github.com/shaneutt/5e1995295cff6721c89a71d13a71c251. -func CreateTLSConfigForTest() (serverTLSConf *tls.Config, clientTLSConf *tls.Config, err error) { +func CreateTempTLS() (*bytes.Buffer, *bytes.Buffer, *bytes.Buffer, error) { // set up our CA certificate ca := &x509.Certificate{ SerialNumber: big.NewInt(2019), @@ -224,27 +123,23 @@ func CreateTLSConfigForTest() (serverTLSConf *tls.Config, clientTLSConf *tls.Con // create our private and public key caPrivKey, err := rsa.GenerateKey(rand.Reader, 4096) if err != nil { - return nil, nil, err + return nil, nil, nil, err } // create the CA caBytes, err := x509.CreateCertificate(rand.Reader, ca, ca, &caPrivKey.PublicKey, caPrivKey) if err != nil { - return nil, nil, err + return nil, nil, nil, err } // pem encode caPEM := new(bytes.Buffer) - pem.Encode(caPEM, &pem.Block{ + if err := pem.Encode(caPEM, &pem.Block{ Type: "CERTIFICATE", Bytes: caBytes, - }) - - caPrivKeyPEM := new(bytes.Buffer) - pem.Encode(caPrivKeyPEM, &pem.Block{ - Type: "RSA PRIVATE KEY", - Bytes: x509.MarshalPKCS1PrivateKey(caPrivKey), - }) + }); err != nil { + return nil, nil, nil, err + } // set up our server certificate cert := &x509.Certificate{ @@ -267,29 +162,45 @@ func CreateTLSConfigForTest() (serverTLSConf *tls.Config, clientTLSConf *tls.Con certPrivKey, err := rsa.GenerateKey(rand.Reader, 4096) if err != nil { - return nil, nil, err + return nil, nil, nil, err } certBytes, err := x509.CreateCertificate(rand.Reader, cert, ca, &certPrivKey.PublicKey, caPrivKey) if err != nil { - return nil, nil, err + return nil, nil, nil, err } certPEM := new(bytes.Buffer) - pem.Encode(certPEM, &pem.Block{ + if err := pem.Encode(certPEM, &pem.Block{ Type: "CERTIFICATE", Bytes: certBytes, - }) + }); err != nil { + return nil, nil, nil, err + } - certPrivKeyPEM := new(bytes.Buffer) - pem.Encode(certPrivKeyPEM, &pem.Block{ + keyPEM := new(bytes.Buffer) + if err := pem.Encode(keyPEM, &pem.Block{ Type: "RSA PRIVATE KEY", Bytes: x509.MarshalPKCS1PrivateKey(certPrivKey), - }) + }); err != nil { + return nil, nil, nil, err + } - serverCert, err := tls.X509KeyPair(certPEM.Bytes(), certPrivKeyPEM.Bytes()) - if err != nil { - return nil, nil, err + return certPEM, keyPEM, caPEM, nil +} + +// CreateTLSConfigForTest is from https://gist.github.com/shaneutt/5e1995295cff6721c89a71d13a71c251. +func CreateTLSConfigForTest() (serverTLSConf *tls.Config, clientTLSConf *tls.Config, err error) { + certPEM, keyPEM, caPEM, uerr := CreateTempTLS() + if uerr != nil { + err = uerr + return + } + + serverCert, uerr := tls.X509KeyPair(certPEM.Bytes(), keyPEM.Bytes()) + if uerr != nil { + err = uerr + return } serverTLSConf = &tls.Config{ @@ -299,8 +210,108 @@ func CreateTLSConfigForTest() (serverTLSConf *tls.Config, clientTLSConf *tls.Con certpool := x509.NewCertPool() certpool.AppendCertsFromPEM(caPEM.Bytes()) clientTLSConf = &tls.Config{ + InsecureSkipVerify: true, RootCAs: certpool, } return } + +func BuildServerTLSConfig(logger *zap.Logger, cfg config.TLSConfig) (*tls.Config, error) { + logger = logger.With(zap.String("tls", "server")) + if !cfg.HasCert() { + logger.Warn("require certificates to secure clients connections, disable TLS") + return nil, nil + } + + tcfg := &tls.Config{} + cert, err := tls.LoadX509KeyPair(cfg.Cert, cfg.Key) + if err != nil { + return nil, errors.Errorf("failed to load certs: %w", err) + } + tcfg.Certificates = append(tcfg.Certificates, cert) + + if !cfg.HasCA() { + logger.Warn("no CA, server will not authenticate clients (connection is still secured)") + return tcfg, nil + } + + tcfg.ClientAuth = tls.RequireAndVerifyClientCert + tcfg.ClientCAs = x509.NewCertPool() + certBytes, err := ioutil.ReadFile(cfg.CA) + if err != nil { + return nil, errors.Errorf("failed to read CA: %w", err) + } + if !tcfg.ClientCAs.AppendCertsFromPEM(certBytes) { + return nil, errors.Errorf("failed to append CA") + } + return tcfg, nil +} + +func BuildClientTLSConfig(logger *zap.Logger, cfg config.TLSConfig) (*tls.Config, error) { + logger = logger.With(zap.String("tls", "client")) + if !cfg.HasCA() { + if cfg.SkipCA { + // still enable TLS without verify server certs + return &tls.Config{InsecureSkipVerify: true}, nil + } + logger.Warn("no CA to verify server connections, disable TLS") + return nil, nil + } + + tcfg := &tls.Config{} + tcfg.RootCAs = x509.NewCertPool() + certBytes, err := ioutil.ReadFile(cfg.CA) + if err != nil { + return nil, errors.Errorf("failed to read CA: %w", err) + } + if !tcfg.RootCAs.AppendCertsFromPEM(certBytes) { + return nil, errors.Errorf("failed to append CA") + } + + if !cfg.HasCert() { + logger.Warn("no certificates, server may reject the connection") + return tcfg, nil + } + cert, err := tls.LoadX509KeyPair(cfg.Cert, cfg.Key) + if err != nil { + return nil, errors.Errorf("failed to load certs for: %w", err) + } + tcfg.Certificates = append(tcfg.Certificates, cert) + + return tcfg, nil +} + +func BuildEtcdTLSConfig(logger *zap.Logger, server, peer config.TLSConfig) (clientInfo, peerInfo transport.TLSInfo, err error) { + logger = logger.With(zap.String("tls", "etcd")) + clientInfo.Logger = logger + peerInfo.Logger = logger + + if server.HasCert() { + clientInfo.CertFile = server.Cert + clientInfo.KeyFile = server.Key + if server.HasCA() { + clientInfo.TrustedCAFile = server.CA + clientInfo.ClientCertAuth = true + } else if !server.SkipCA { + logger.Warn("no CA, proxy will not authenticate etcd clients (connection is still secured)") + } + } + + if peer.HasCert() { + peerInfo.CertFile = peer.Cert + peerInfo.KeyFile = peer.Key + if peer.HasCA() { + peerInfo.TrustedCAFile = peer.CA + peerInfo.ClientCertAuth = true + } else if peer.SkipCA { + peerInfo.InsecureSkipVerify = true + peerInfo.ClientCertAuth = false + } else { + err = errors.New("need a full set of cert/key/ca or cert/key/skip-ca to secure etcd peer inter-communication") + return + } + } + + return +} diff --git a/pkg/manager/config/manager.go b/pkg/manager/config/manager.go index 112d8c55..b6692e73 100644 --- a/pkg/manager/config/manager.go +++ b/pkg/manager/config/manager.go @@ -24,12 +24,13 @@ import ( "github.com/pingcap/TiProxy/lib/util/waitgroup" "go.etcd.io/etcd/api/v3/mvccpb" clientv3 "go.etcd.io/etcd/client/v3" + "go.etcd.io/etcd/server/v3/lease" + "go.etcd.io/etcd/server/v3/mvcc" "go.uber.org/zap" ) const ( DefaultEtcdDialTimeout = 3 * time.Second - DefaultWatchInterval = 10 * time.Minute DefaultEtcdPath = "/config" PathPrefixNamespace = "ns" @@ -41,12 +42,11 @@ var ( ) type ConfigManager struct { - wg waitgroup.WaitGroup - cancel context.CancelFunc - logger *zap.Logger - etcdClient *clientv3.Client - kv clientv3.KV - basePath string + wg waitgroup.WaitGroup + cancel context.CancelFunc + logger *zap.Logger + kv mvcc.WatchableKV + basePath string // config ignoreWrongNamespace bool @@ -61,32 +61,13 @@ func NewConfigManager() *ConfigManager { } } -func (srv *ConfigManager) Init(ctx context.Context, addrs []string, cfg config.Advance, logger *zap.Logger) error { +func (srv *ConfigManager) Init(ctx context.Context, kv mvcc.WatchableKV, cfg config.Advance, logger *zap.Logger) error { srv.logger = logger srv.ignoreWrongNamespace = cfg.IgnoreWrongNamespace - if cfg.WatchInterval == "" { - srv.watchInterval = DefaultWatchInterval - } else { - wi, err := time.ParseDuration(cfg.WatchInterval) - if err != nil { - return errors.Wrapf(err, "failed to parse watch interval %s", cfg.WatchInterval) - } - srv.watchInterval = wi - } // slash appended to distinguish '/dir'(file) and '/dir/'(directory) srv.basePath = appendSlashToDirPath(DefaultEtcdPath) - etcdConfig := clientv3.Config{ - Endpoints: addrs, - DialTimeout: DefaultEtcdDialTimeout, - } - - etcdClient, err := clientv3.New(etcdConfig) - if err != nil { - return errors.Wrapf(err, "create etcd config center error") - } - srv.etcdClient = etcdClient - srv.kv = clientv3.NewKV(srv.etcdClient) + srv.kv = kv ctx, cancel := context.WithCancel(ctx) srv.cancel = cancel @@ -96,103 +77,85 @@ func (srv *ConfigManager) Init(ctx context.Context, addrs []string, cfg config.A return nil } -func (e *ConfigManager) watch(ctx context.Context, ns, key string, f func(*zap.Logger, *clientv3.Event)) { - wkey := path.Join(e.basePath, ns, key) - logger := e.logger.With(zap.String("component", wkey)) +func (e *ConfigManager) watch(ctx context.Context, ns, key string, f func(*zap.Logger, mvccpb.Event)) { + wkey := []byte(path.Join(e.basePath, ns, key)) + logger := e.logger.With(zap.String("component", string(wkey))) retryInterval := 5 * time.Second e.wg.Run(func() { - var prevKV *mvccpb.KeyValue - - ticker := time.NewTicker(e.watchInterval) - defer ticker.Stop() - - wch := e.etcdClient.Watch(ctx, wkey) + wch := e.kv.NewWatchStream() + defer wch.Close() for { + if _, err := wch.Watch(mvcc.AutoWatchID, wkey, getPrefix(wkey), wch.Rev()-1); err == nil { + break + } + if k := retryInterval * 2; k < e.watchInterval { + retryInterval = k + } + logger.Warn("failed to watch, will try again later", zap.Duration("sleep", retryInterval)) select { case <-ctx.Done(): return - case <-ticker.C: - resp, err := e.kv.Get(ctx, wkey) - if err != nil { - logger.Warn("failed to poll", zap.Error(err)) - break - } - // len == 0 may mean there is no value set yet, do not warn about that - if len(resp.Kvs) > 1 { - logger.Warn("failed to poll", zap.Error(ErrNoOrMultiResults)) - break - } else if len(resp.Kvs) == 1 { - f(logger, &clientv3.Event{ - Type: mvccpb.PUT, - Kv: resp.Kvs[0], - PrevKv: prevKV, - }) - prevKV = resp.Kvs[0] - } - case res := <-wch: - if res.Canceled { - // don't wait for more than the polling interval - if k := retryInterval * 2; k < e.watchInterval { - retryInterval = k - } - logger.Warn("failed to watch, will try again later", zap.Error(res.Err()), zap.Duration("sleep", retryInterval)) - time.Sleep(retryInterval) - wch = e.etcdClient.Watch(ctx, wkey, clientv3.WithCreatedNotify()) - break - } - + case <-time.After(retryInterval): + } + } + for { + select { + case <-ctx.Done(): + return + case res := <-wch.Chan(): for _, evt := range res.Events { f(logger, evt) - prevKV = evt.Kv } - - // reset the ticker to prevent another tick immediately - ticker.Reset(e.watchInterval) } } }) } func (e *ConfigManager) get(ctx context.Context, ns, key string) (*mvccpb.KeyValue, error) { - resp, err := e.kv.Get(ctx, path.Join(e.basePath, ns, key)) + resp, err := e.kv.Range(ctx, []byte(path.Join(e.basePath, ns, key)), nil, mvcc.RangeOptions{Rev: 0}) if err != nil { return nil, err } - if len(resp.Kvs) != 1 { + if len(resp.KVs) != 1 { return nil, ErrNoOrMultiResults } - return resp.Kvs[0], nil + return &resp.KVs[0], nil } -func (e *ConfigManager) list(ctx context.Context, ns string, ops ...clientv3.OpOption) ([]*mvccpb.KeyValue, error) { - options := make([]clientv3.OpOption, 1, 1+len(ops)) - options[0] = clientv3.WithPrefix() - options = append(options, ops...) - resp, err := e.kv.Get(ctx, path.Join(e.basePath, ns), options...) - if err != nil { - return nil, err +func getPrefix(key []byte) []byte { + end := make([]byte, len(key)) + copy(end, key) + for i := len(end) - 1; i >= 0; i-- { + if end[i] < 0xff { + end[i] = end[i] + 1 + end = end[:i+1] + return end + } } - return resp.Kvs, nil + return []byte{0} } -func (e *ConfigManager) set(ctx context.Context, ns, key, val string) (*mvccpb.KeyValue, error) { - resp, err := e.kv.Put(ctx, path.Join(e.basePath, ns, key), val) +func (e *ConfigManager) list(ctx context.Context, ns string, ops ...clientv3.OpOption) ([]mvccpb.KeyValue, error) { + k := []byte(path.Join(e.basePath, ns)) + resp, err := e.kv.Range(ctx, k, getPrefix(k), mvcc.RangeOptions{Rev: 0}) if err != nil { return nil, err } - return resp.PrevKv, nil + return resp.KVs, nil +} + +func (e *ConfigManager) set(ctx context.Context, ns, key, val string) error { + _ = e.kv.Put([]byte(path.Join(e.basePath, ns, key)), []byte(val), lease.NoLease) + return nil } func (e *ConfigManager) del(ctx context.Context, ns, key string) error { - _, err := e.kv.Delete(ctx, path.Join(e.basePath, ns, key)) - if err != nil { - return err - } + _, _ = e.kv.DeleteRange([]byte(path.Join(e.basePath, ns, key)), nil) return nil } func (e *ConfigManager) Close() error { e.cancel() e.wg.Wait() - return errors.Wrapf(e.etcdClient.Close(), "fail to close config manager") + return nil } diff --git a/pkg/manager/config/manager_test.go b/pkg/manager/config/manager_test.go index deca5c06..fd99f794 100644 --- a/pkg/manager/config/manager_test.go +++ b/pkg/manager/config/manager_test.go @@ -21,11 +21,13 @@ import ( "path" "path/filepath" "testing" + "time" "github.com/pingcap/TiProxy/lib/config" "github.com/pingcap/TiProxy/lib/util/logger" "github.com/pingcap/TiProxy/lib/util/waitgroup" "github.com/stretchr/testify/require" + "go.etcd.io/etcd/api/v3/mvccpb" clientv3 "go.etcd.io/etcd/client/v3" "go.etcd.io/etcd/server/v3/embed" "go.uber.org/zap" @@ -37,28 +39,23 @@ func testConfigManager(t *testing.T, cfg config.Advance) (*ConfigManager, contex testDir := t.TempDir() - log := logger.CreateLoggerForTest(t) + logger := logger.CreateLoggerForTest(t) etcd_cfg := embed.NewConfig() etcd_cfg.LCUrls = []url.URL{*addr} etcd_cfg.LPUrls = []url.URL{*addr} etcd_cfg.Dir = filepath.Join(testDir, "etcd") - etcd_cfg.ZapLoggerBuilder = embed.NewZapLoggerBuilder(log.Named("etcd")) + etcd_cfg.ZapLoggerBuilder = embed.NewZapLoggerBuilder(logger.Named("etcd")) etcd, err := embed.StartEtcd(etcd_cfg) require.NoError(t, err) - ends := make([]string, len(etcd.Clients)) - for i := range ends { - ends[i] = etcd.Clients[i].Addr().String() - } - ctx, cancel := context.WithCancel(context.Background()) if ddl, ok := t.Deadline(); ok { ctx, cancel = context.WithDeadline(ctx, ddl) } cfgmgr := NewConfigManager() - require.NoError(t, cfgmgr.Init(ctx, ends, cfg, log)) + require.NoError(t, cfgmgr.Init(ctx, etcd.Server.KV(), cfg, logger)) t.Cleanup(func() { require.NoError(t, cfgmgr.Close()) @@ -89,8 +86,7 @@ func TestBase(t *testing.T) { ns := getNs(i) for j := 0; j < valNum; j++ { k := getKey(j) - _, err := cfgmgr.set(ctx, ns, k, k) - require.NoError(t, err) + require.NoError(t, cfgmgr.set(ctx, ns, k, k)) } } @@ -123,8 +119,7 @@ func TestBase(t *testing.T) { ns := getNs(i) for j := 0; j < valNum; j++ { k := getKey(j) - _, err := cfgmgr.set(ctx, ns, k, k) - require.NoError(t, err) + require.NoError(t, cfgmgr.set(ctx, ns, k, k)) require.NoError(t, cfgmgr.del(ctx, ns, k)) } @@ -144,8 +139,7 @@ func TestBaseConcurrency(t *testing.T) { for i := 0; i < batchNum; i++ { k := fmt.Sprint(i) wg.Run(func() { - _, err := cfgmgr.set(ctx, k, "1", "1") - require.NoError(t, err) + require.NoError(t, cfgmgr.set(ctx, k, "1", "1")) }) wg.Run(func() { @@ -158,16 +152,14 @@ func TestBaseConcurrency(t *testing.T) { for i := 0; i < batchNum; i++ { k := fmt.Sprint(i) - _, err := cfgmgr.set(ctx, k, "1", "1") - require.NoError(t, err) + require.NoError(t, cfgmgr.set(ctx, k, "1", "1")) } for i := 0; i < batchNum; i++ { k := fmt.Sprint(i) wg.Run(func() { - _, err := cfgmgr.set(ctx, k, "1", "1") - require.NoError(t, err) + require.NoError(t, cfgmgr.set(ctx, k, "1", "1")) }) wg.Run(func() { @@ -181,11 +173,10 @@ func TestBaseConcurrency(t *testing.T) { func TestBaseWatch(t *testing.T) { cfgmgr, ctx := testConfigManager(t, config.Advance{ IgnoreWrongNamespace: true, - WatchInterval: "1s", }) ch := make(chan string, 1) - cfgmgr.watch(ctx, "test", "t", func(l *zap.Logger, e *clientv3.Event) { + cfgmgr.watch(ctx, "test", "t", func(l *zap.Logger, e mvccpb.Event) { ch <- string(e.Kv.Value) }) @@ -195,23 +186,13 @@ func TestBaseWatch(t *testing.T) { } // set it - _, err := cfgmgr.set(ctx, "test", "t", "1") - require.NoError(t, err) - - // check multiple times, it will become the value after some point for at least three times - count := 0 - for i := 0; i < 10; i++ { - val := <-ch - if val == "1" { - count++ - } else if count != 0 { - t.Fatal("watched value changed after setting it to 1") - } - if count == 3 { - break - } - } - if count < 3 { - t.Fatal("should met the same value at least two times, one from polling, one from notify, one from created") + require.NoError(t, cfgmgr.set(ctx, "test", "t", "1")) + + // now the only way to check watch is to wait + select { + case <-time.After(5 * time.Second): + t.Fatal("timeout waiting chan") + case tg := <-ch: + require.Equal(t, "1", tg) } } diff --git a/pkg/manager/config/namespace.go b/pkg/manager/config/namespace.go index 2b6a0d53..54db7712 100644 --- a/pkg/manager/config/namespace.go +++ b/pkg/manager/config/namespace.go @@ -64,8 +64,7 @@ func (e *ConfigManager) SetNamespace(ctx context.Context, ns string, nsc *config if err != nil { return err } - _, err = e.set(ctx, PathPrefixNamespace, ns, string(r)) - return err + return e.set(ctx, PathPrefixNamespace, ns, string(r)) } func (e *ConfigManager) DelNamespace(ctx context.Context, ns string) error { diff --git a/pkg/manager/config/proxy.go b/pkg/manager/config/proxy.go index 6730d214..59a3ffd3 100644 --- a/pkg/manager/config/proxy.go +++ b/pkg/manager/config/proxy.go @@ -19,12 +19,12 @@ import ( "encoding/json" "github.com/pingcap/TiProxy/lib/config" - clientv3 "go.etcd.io/etcd/client/v3" + "go.etcd.io/etcd/api/v3/mvccpb" "go.uber.org/zap" ) func (e *ConfigManager) initProxyConfig(ctx context.Context) { - e.watch(ctx, PathPrefixProxy, "config", func(logger *zap.Logger, evt *clientv3.Event) { + e.watch(ctx, PathPrefixProxy, "config", func(logger *zap.Logger, evt mvccpb.Event) { var proxy config.ProxyServerOnline if err := json.Unmarshal(evt.Kv.Value, &proxy); err != nil { logger.Warn("failed unmarshal proxy config", zap.Error(err)) @@ -43,6 +43,5 @@ func (e *ConfigManager) SetProxyConfig(ctx context.Context, proxy *config.ProxyS if err != nil { return err } - _, err = e.set(ctx, PathPrefixProxy, "config", string(value)) - return err + return e.set(ctx, PathPrefixProxy, "config", string(value)) } diff --git a/pkg/manager/config/proxy_test.go b/pkg/manager/config/proxy_test.go index ca4f08ba..b0615215 100644 --- a/pkg/manager/config/proxy_test.go +++ b/pkg/manager/config/proxy_test.go @@ -45,6 +45,7 @@ func TestProxyConfig(t *testing.T) { TCPKeepAlive: true, }, } + ch := cfgmgr.GetProxyConfig() for _, tc := range cases { require.NoError(t, cfgmgr.SetProxyConfig(ctx, tc)) diff --git a/pkg/manager/namespace/manager.go b/pkg/manager/namespace/manager.go index ea26ac69..ff8dfa08 100644 --- a/pkg/manager/namespace/manager.go +++ b/pkg/manager/namespace/manager.go @@ -16,32 +16,33 @@ package namespace import ( - "crypto/tls" - "crypto/x509" "fmt" - "io/ioutil" + "net/http" "sync" "github.com/pingcap/TiProxy/lib/config" - "github.com/pingcap/TiProxy/pkg/manager/router" "github.com/pingcap/TiProxy/lib/util/errors" + "github.com/pingcap/TiProxy/lib/util/security" + "github.com/pingcap/TiProxy/pkg/manager/router" clientv3 "go.etcd.io/etcd/client/v3" "go.uber.org/zap" ) type NamespaceManager struct { sync.RWMutex - client *clientv3.Client - logger *zap.Logger - nsm map[string]*Namespace + client *clientv3.Client + httpCli *http.Client + logger *zap.Logger + nsm map[string]*Namespace } func NewNamespaceManager() *NamespaceManager { return &NamespaceManager{} } -func (mgr *NamespaceManager) buildNamespace(cfg *config.Namespace, client *clientv3.Client) (*Namespace, error) { +func (mgr *NamespaceManager) buildNamespace(cfg *config.Namespace) (*Namespace, error) { logger := mgr.logger.With(zap.String("namespace", cfg.Namespace)) - rt, err := router.NewScoreBasedRouter(&cfg.Backend, client) + + rt, err := router.NewScoreBasedRouter(&cfg.Backend, mgr.client, mgr.httpCli) if err != nil { return nil, errors.Errorf("build router error: %w", err) } @@ -50,62 +51,14 @@ func (mgr *NamespaceManager) buildNamespace(cfg *config.Namespace, client *clien router: rt, } - // frontend tls configuration - { - r.frontendTLS = &tls.Config{} - - if !cfg.Frontend.Security.HasCert() { - // TODO: require certs here - logger.Warn("require certificates to secure frontend tls connections") - } else { - cert, err := tls.LoadX509KeyPair(cfg.Frontend.Security.Cert, cfg.Frontend.Security.Key) - if err != nil { - return nil, errors.Errorf("failed to load server certs: %w", err) - } - r.frontendTLS.Certificates = append(r.frontendTLS.Certificates, cert) - } - - if cfg.Frontend.Security.HasCA() { - r.frontendTLS.ClientAuth = tls.RequireAndVerifyClientCert - r.frontendTLS.ClientCAs = x509.NewCertPool() - certBytes, err := ioutil.ReadFile(cfg.Frontend.Security.CA) - if err != nil { - return nil, errors.Errorf("failed to read server signed certs from disk: %w", err) - } - if !r.frontendTLS.ClientCAs.AppendCertsFromPEM(certBytes) { - return nil, errors.Errorf("failed to load server signed certs") - } - } else { - logger.Warn("no signed certs for frontend, proxy will not authenticate clients (connection is still secured)") - } + r.frontendTLS, err = security.BuildServerTLSConfig(logger, cfg.Frontend.Security) + if err != nil { + return nil, errors.Errorf("build frontend TLS error: %w", err) } - { - r.backendTLS = &tls.Config{} - // backend tls configuration - if !cfg.Backend.Security.HasCA() { - // TODO: require certs here - logger.Error("require signed certs to verify backend tls connections") - } else { - r.backendTLS.RootCAs = x509.NewCertPool() - certBytes, err := ioutil.ReadFile(cfg.Backend.Security.CA) - if err != nil { - return nil, errors.Errorf("failed to read server signed certs from disk: %w", err) - } - if !r.backendTLS.RootCAs.AppendCertsFromPEM(certBytes) { - return nil, errors.Errorf("failed to load server signed certs") - } - } - - if cfg.Backend.Security.HasCert() { - cert, err := tls.LoadX509KeyPair(cfg.Backend.Security.Cert, cfg.Backend.Security.Key) - if err != nil { - return nil, errors.Errorf("failed to load cluster certs: %w", err) - } - r.backendTLS.Certificates = append(r.backendTLS.Certificates, cert) - } else { - logger.Warn("no certs for backend authentication, backend may reject proxy connections (connection is still secured)") - } + r.backendTLS, err = security.BuildClientTLSConfig(logger, cfg.Backend.Security) + if err != nil { + return nil, errors.Errorf("build backend TLS error: %w", err) } return r, nil @@ -125,7 +78,7 @@ func (mgr *NamespaceManager) CommitNamespaces(nss []*config.Namespace, nss_delet continue } - ns, err := mgr.buildNamespace(nsc, mgr.client) + ns, err := mgr.buildNamespace(nsc) if err != nil { return fmt.Errorf("%w: create namespace error, namespace: %s", err, nsc.Namespace) } @@ -138,9 +91,10 @@ func (mgr *NamespaceManager) CommitNamespaces(nss []*config.Namespace, nss_delet return nil } -func (mgr *NamespaceManager) Init(logger *zap.Logger, nss []*config.Namespace, client *clientv3.Client) error { +func (mgr *NamespaceManager) Init(logger *zap.Logger, nss []*config.Namespace, client *clientv3.Client, httpCli *http.Client) error { mgr.Lock() mgr.client = client + mgr.httpCli = httpCli mgr.logger = logger mgr.Unlock() diff --git a/pkg/manager/router/backend_observer.go b/pkg/manager/router/backend_observer.go index 163f3977..ee0dcc61 100644 --- a/pkg/manager/router/backend_observer.go +++ b/pkg/manager/router/backend_observer.go @@ -132,6 +132,8 @@ type BackendObserver struct { // All the backend info in the topology, including tombstones. allBackendInfo map[string]*BackendInfo client *clientv3.Client + httpCli *http.Client + httpTLS bool staticAddrs []string eventReceiver BackendEventReceiver wg waitgroup.WaitGroup @@ -139,16 +141,15 @@ type BackendObserver struct { } // InitEtcdClient initializes an etcd client that fetches TiDB instance topology from PD. -func InitEtcdClient(cfg *config.Config) (*clientv3.Client, error) { +func InitEtcdClient(logger *zap.Logger, cfg *config.Config) (*clientv3.Client, error) { pdAddr := cfg.Proxy.PDAddrs if len(pdAddr) == 0 { // use tidb server addresses directly return nil, nil } pdEndpoints := strings.Split(pdAddr, ",") - logConfig := zap.NewProductionConfig() - logConfig.Level = zap.NewAtomicLevelAt(zap.ErrorLevel) - tlsConfig, err := security.CreateClusterTLSConfig(cfg.Security.Cluster.CA, cfg.Security.Cluster.Key, cfg.Security.Cluster.Cert) + logger.Info("connect PD servers", zap.Strings("addrs", pdEndpoints)) + tlsConfig, err := security.BuildClientTLSConfig(logger, cfg.Security.ClusterTLS) if err != nil { return nil, err } @@ -156,7 +157,7 @@ func InitEtcdClient(cfg *config.Config) (*clientv3.Client, error) { etcdClient, err = clientv3.New(clientv3.Config{ Endpoints: pdEndpoints, TLS: tlsConfig, - LogConfig: &logConfig, + Logger: logger.Named("etcdcli"), AutoSyncInterval: 30 * time.Second, DialTimeout: 5 * time.Second, DialOptions: []grpc.DialOption{ @@ -181,8 +182,8 @@ func InitEtcdClient(cfg *config.Config) (*clientv3.Client, error) { } // StartBackendObserver creates a BackendObserver and starts watching. -func StartBackendObserver(eventReceiver BackendEventReceiver, client *clientv3.Client, config *HealthCheckConfig, staticAddrs []string) (*BackendObserver, error) { - bo, err := NewBackendObserver(eventReceiver, client, config, staticAddrs) +func StartBackendObserver(eventReceiver BackendEventReceiver, client *clientv3.Client, httpCli *http.Client, config *HealthCheckConfig, staticAddrs []string) (*BackendObserver, error) { + bo, err := NewBackendObserver(eventReceiver, client, httpCli, config, staticAddrs) if err != nil { return nil, err } @@ -191,15 +192,24 @@ func StartBackendObserver(eventReceiver BackendEventReceiver, client *clientv3.C } // NewBackendObserver creates a BackendObserver. -func NewBackendObserver(eventReceiver BackendEventReceiver, client *clientv3.Client, config *HealthCheckConfig, staticAddrs []string) (*BackendObserver, error) { +func NewBackendObserver(eventReceiver BackendEventReceiver, client *clientv3.Client, httpCli *http.Client, config *HealthCheckConfig, staticAddrs []string) (*BackendObserver, error) { if client == nil && len(staticAddrs) == 0 { return nil, ErrNoInstanceToSelect } + if httpCli == nil { + httpCli = http.DefaultClient + } + httpTLS := false + if v, ok := httpCli.Transport.(*http.Transport); ok && v != nil && v.TLSClientConfig != nil { + httpTLS = true + } bo := &BackendObserver{ config: config, curBackendInfo: make(map[string]BackendStatus), allBackendInfo: make(map[string]*BackendInfo), client: client, + httpCli: httpCli, + httpTLS: httpTLS, staticAddrs: staticAddrs, eventReceiver: eventReceiver, } @@ -381,11 +391,15 @@ func (bo *BackendObserver) checkHealth(ctx context.Context, backends map[string] // When a backend gracefully shut down, the status port returns 500 but the SQL port still accepts // new connections, so we must check the status port first. - url := fmt.Sprintf("http://%s:%d%s", info.IP, info.StatusPort, statusPathSuffix) + schema := "http" + if bo.httpTLS { + schema = "https" + } + url := fmt.Sprintf("%s://%s:%d%s", schema, info.IP, info.StatusPort, statusPathSuffix) var resp *http.Response err := connectWithRetry(func() error { var err error - if resp, err = http.Get(url); err == nil { + if resp, err = bo.httpCli.Get(url); err == nil { if err := resp.Body.Close(); err != nil { logutil.Logger(ctx).Warn("close http response in health check failed", zap.Error(err)) } diff --git a/pkg/manager/router/backend_observer_test.go b/pkg/manager/router/backend_observer_test.go index cca2777d..d2252f5f 100644 --- a/pkg/manager/router/backend_observer_test.go +++ b/pkg/manager/router/backend_observer_test.go @@ -28,6 +28,7 @@ import ( "time" "github.com/pingcap/TiProxy/lib/config" + "github.com/pingcap/TiProxy/lib/util/logger" "github.com/pingcap/TiProxy/lib/util/waitgroup" "github.com/pingcap/tidb/domain/infosync" "github.com/stretchr/testify/require" @@ -65,7 +66,7 @@ func TestObserveBackends(t *testing.T) { runTest(t, func(etcd *embed.Etcd, kv clientv3.KV, bo *BackendObserver, backendChan chan map[string]BackendStatus) { bo.Start() - backend1 := addBackend(t, kv, backendChan) + backend1 := addBackend(t, kv) checkStatus(t, backendChan, backend1, StatusHealthy) addFakeTopology(t, kv, backend1.sqlAddr) backend1.stopSQLServer() @@ -81,7 +82,7 @@ func TestObserveBackends(t *testing.T) { backend1.startHTTPServer() checkStatus(t, backendChan, backend1, StatusHealthy) - backend2 := addBackend(t, kv, backendChan) + backend2 := addBackend(t, kv) checkStatus(t, backendChan, backend2, StatusHealthy) removeBackend(t, kv, backend2) checkStatus(t, backendChan, backend2, StatusCannotConnect) @@ -170,7 +171,7 @@ func TestCancelObserver(t *testing.T) { runTest(t, func(etcd *embed.Etcd, kv clientv3.KV, bo *BackendObserver, backendChan chan map[string]BackendStatus) { backends := make([]*backendServer, 0, 3) for i := 0; i < 3; i++ { - backends = append(backends, addBackend(t, kv, backendChan)) + backends = append(backends, addBackend(t, kv)) } err := bo.fetchBackendList(context.Background()) require.NoError(t, err) @@ -214,7 +215,7 @@ func runTest(t *testing.T, f func(etcd *embed.Etcd, kv clientv3.KV, bo *BackendO kv := clientv3.NewKV(client) backendChan := make(chan map[string]BackendStatus, 1) mer := newMockEventReceiver(backendChan) - bo, err := NewBackendObserver(mer, client, newHealthCheckConfigForTest(), nil) + bo, err := NewBackendObserver(mer, client, nil, newHealthCheckConfigForTest(), nil) require.NoError(t, err) f(etcd, kv, bo, backendChan) bo.Close() @@ -241,7 +242,7 @@ func createEtcdClient(t *testing.T, etcd *embed.Etcd) *clientv3.Client { PDAddrs: etcd.Clients[0].Addr().String(), }, } - client, err := InitEtcdClient(cfg) + client, err := InitEtcdClient(logger.CreateLoggerForTest(t), cfg) require.NoError(t, err) t.Cleanup(func() { require.NoError(t, client.Close()) @@ -332,7 +333,7 @@ func startListener(t *testing.T, addr string) (net.Listener, string) { } // A new healthy backend is added. -func addBackend(t *testing.T, kv clientv3.KV, backendChan chan map[string]BackendStatus) *backendServer { +func addBackend(t *testing.T, kv clientv3.KV) *backendServer { backend := &backendServer{ t: t, } diff --git a/pkg/manager/router/router.go b/pkg/manager/router/router.go index e644f667..7273b0ed 100644 --- a/pkg/manager/router/router.go +++ b/pkg/manager/router/router.go @@ -17,6 +17,7 @@ package router import ( "container/list" "context" + "net/http" "sync" "time" @@ -116,13 +117,13 @@ type ScoreBasedRouter struct { } // NewScoreBasedRouter creates a ScoreBasedRouter. -func NewScoreBasedRouter(cfg *config.BackendNamespace, client *clientv3.Client) (*ScoreBasedRouter, error) { +func NewScoreBasedRouter(cfg *config.BackendNamespace, client *clientv3.Client, httpCli *http.Client) (*ScoreBasedRouter, error) { router := &ScoreBasedRouter{ backends: list.New(), } router.Lock() defer router.Unlock() - observer, err := StartBackendObserver(router, client, newDefaultHealthCheckConfig(), cfg.Instances) + observer, err := StartBackendObserver(router, client, httpCli, newDefaultHealthCheckConfig(), cfg.Instances) if err != nil { return nil, err } diff --git a/pkg/manager/router/router_test.go b/pkg/manager/router/router_test.go index e9553736..8b739932 100644 --- a/pkg/manager/router/router_test.go +++ b/pkg/manager/router/router_test.go @@ -523,7 +523,7 @@ func TestConcurrency(t *testing.T) { // We create other goroutines to change backends easily. etcd := createEtcdServer(t, "127.0.0.1:0") client := createEtcdClient(t, etcd) - router, err := NewScoreBasedRouter(cfg, client) + router, err := NewScoreBasedRouter(cfg, client, nil) require.NoError(t, err) var wg waitgroup.WaitGroup diff --git a/pkg/proxy/backend/authenticator.go b/pkg/proxy/backend/authenticator.go index f2fa6b53..37667210 100644 --- a/pkg/proxy/backend/authenticator.go +++ b/pkg/proxy/backend/authenticator.go @@ -18,9 +18,10 @@ import ( "crypto/tls" "encoding/binary" "fmt" + "net" - pnet "github.com/pingcap/TiProxy/pkg/proxy/net" "github.com/pingcap/TiProxy/lib/util/errors" + pnet "github.com/pingcap/TiProxy/pkg/proxy/net" "github.com/pingcap/tidb/parser/mysql" "github.com/pingcap/tidb/util/hack" ) @@ -39,6 +40,7 @@ type Authenticator struct { dbname string // default database name capability uint32 // client capability collation uint8 + serverAddr string attrs []byte // no need to parse backendTLSConfig *tls.Config } @@ -48,7 +50,7 @@ func (auth *Authenticator) String() string { auth.user, auth.dbname, auth.capability, auth.collation) } -func (auth *Authenticator) handshakeFirstTime(clientIO, backendIO *pnet.PacketIO, serverTLSConfig, backendTLSConfig *tls.Config) error { +func (auth *Authenticator) handshakeFirstTime(clientIO, backendIO *pnet.PacketIO, frontendTLSConfig, backendTLSConfig *tls.Config) error { backendIO.ResetSequence() // Read initial handshake packet from the backend. serverPkt, serverCapability, err := auth.readInitialHandshake(backendIO) @@ -76,7 +78,7 @@ func (auth *Authenticator) handshakeFirstTime(clientIO, backendIO *pnet.PacketIO sslEnabled := uint32(clientCapability)&mysql.ClientSSL > 0 if sslEnabled { // Upgrade TLS with the client if SSL is enabled. - if _, err = clientIO.UpgradeToServerTLS(serverTLSConfig); err != nil { + if _, err = clientIO.UpgradeToServerTLS(frontendTLSConfig); err != nil { return err } } else { @@ -90,8 +92,19 @@ func (auth *Authenticator) handshakeFirstTime(clientIO, backendIO *pnet.PacketIO return err } // Always upgrade TLS with the server. - auth.backendTLSConfig = backendTLSConfig - if err = backendIO.UpgradeToClientTLS(backendTLSConfig); err != nil { + auth.backendTLSConfig = backendTLSConfig.Clone() + addr := backendIO.RemoteAddr().String() + if auth.serverAddr != "" { + // NOTE: should use DNS name as much as possible + // Usally certs are signed with domain instead of IP addrs + // And `RemoteAddr()` will return IP addr + addr = auth.serverAddr + } + host, _, err := net.SplitHostPort(addr) + if err == nil { + auth.backendTLSConfig.ServerName = host + } + if err = backendIO.UpgradeToClientTLS(auth.backendTLSConfig); err != nil { return err } if sslEnabled { diff --git a/pkg/proxy/backend/backend_conn_mgr.go b/pkg/proxy/backend/backend_conn_mgr.go index 79a504c8..75b8db88 100644 --- a/pkg/proxy/backend/backend_conn_mgr.go +++ b/pkg/proxy/backend/backend_conn_mgr.go @@ -102,7 +102,7 @@ func (mgr *BackendConnManager) ConnectionID() uint64 { } // Connect connects to the first backend and then start watching redirection signals. -func (mgr *BackendConnManager) Connect(ctx context.Context, serverAddr string, clientIO *pnet.PacketIO, serverTLSConfig, backendTLSConfig *tls.Config) error { +func (mgr *BackendConnManager) Connect(ctx context.Context, serverAddr string, clientIO *pnet.PacketIO, frontendTLSConfig, backendTLSConfig *tls.Config) error { mgr.processLock.Lock() defer mgr.processLock.Unlock() mgr.backendConn = NewBackendConnection(serverAddr) @@ -110,7 +110,8 @@ func (mgr *BackendConnManager) Connect(ctx context.Context, serverAddr string, c return err } backendIO := mgr.backendConn.PacketIO() - if err := mgr.authenticator.handshakeFirstTime(clientIO, backendIO, serverTLSConfig, backendTLSConfig); err != nil { + mgr.authenticator.serverAddr = serverAddr + if err := mgr.authenticator.handshakeFirstTime(clientIO, backendIO, frontendTLSConfig, backendTLSConfig); err != nil { return err } mgr.cmdProcessor.capability = mgr.authenticator.capability diff --git a/pkg/proxy/client/client_conn.go b/pkg/proxy/client/client_conn.go index 3f47f790..422a1cb2 100644 --- a/pkg/proxy/client/client_conn.go +++ b/pkg/proxy/client/client_conn.go @@ -28,24 +28,24 @@ import ( ) type ClientConnection struct { - logger *zap.Logger - serverTLSConfig *tls.Config // the TLS config to connect to clients. - backendTLSConfig *tls.Config // the TLS config to connect to TiDB server. - pkt *pnet.PacketIO // a helper to read and write data in packet format. - nsmgr *namespace.NamespaceManager - ns *namespace.Namespace - connMgr *backend.BackendConnManager + logger *zap.Logger + frontendTLSConfig *tls.Config // the TLS config to connect to clients. + backendTLSConfig *tls.Config // the TLS config to connect to TiDB server. + pkt *pnet.PacketIO // a helper to read and write data in packet format. + nsmgr *namespace.NamespaceManager + ns *namespace.Namespace + connMgr *backend.BackendConnManager } -func NewClientConnection(logger *zap.Logger, conn net.Conn, serverTLSConfig *tls.Config, backendTLSConfig *tls.Config, nsmgr *namespace.NamespaceManager, bemgr *backend.BackendConnManager) *ClientConnection { +func NewClientConnection(logger *zap.Logger, conn net.Conn, frontendTLSConfig *tls.Config, backendTLSConfig *tls.Config, nsmgr *namespace.NamespaceManager, bemgr *backend.BackendConnManager) *ClientConnection { pkt := pnet.NewPacketIO(conn) return &ClientConnection{ - logger: logger, - serverTLSConfig: serverTLSConfig, - backendTLSConfig: backendTLSConfig, - pkt: pkt, - nsmgr: nsmgr, - connMgr: bemgr, + logger: logger, + frontendTLSConfig: frontendTLSConfig, + backendTLSConfig: backendTLSConfig, + pkt: pkt, + nsmgr: nsmgr, + connMgr: bemgr, } } @@ -64,7 +64,7 @@ func (cc *ClientConnection) connectBackend(ctx context.Context) error { if err != nil { return err } - if err = cc.connMgr.Connect(ctx, addr, cc.pkt, cc.serverTLSConfig, cc.backendTLSConfig); err != nil { + if err = cc.connMgr.Connect(ctx, addr, cc.pkt, cc.frontendTLSConfig, cc.backendTLSConfig); err != nil { return err } return nil diff --git a/pkg/proxy/net/tls.go b/pkg/proxy/net/tls.go index 1bf91d91..d9d1700c 100644 --- a/pkg/proxy/net/tls.go +++ b/pkg/proxy/net/tls.go @@ -16,7 +16,6 @@ package net import ( "crypto/tls" - "net" "github.com/pingcap/TiProxy/lib/util/errors" ) @@ -34,10 +33,6 @@ func (p *PacketIO) UpgradeToServerTLS(tlsConfig *tls.Config) (tls.ConnectionStat func (p *PacketIO) UpgradeToClientTLS(tlsConfig *tls.Config) error { tlsConfig = tlsConfig.Clone() - host, _, err := net.SplitHostPort(p.conn.RemoteAddr().String()) - if err == nil { - tlsConfig.ServerName = host - } tlsConn := tls.Client(p.conn, tlsConfig) if err := tlsConn.Handshake(); err != nil { return errors.WithStack(errors.Wrap(ErrHandshakeTLS, err)) diff --git a/pkg/proxy/proxy.go b/pkg/proxy/proxy.go index ecf40a20..77f29b31 100644 --- a/pkg/proxy/proxy.go +++ b/pkg/proxy/proxy.go @@ -40,18 +40,18 @@ type serverState struct { } type SQLServer struct { - listener net.Listener - logger *zap.Logger - nsmgr *mgrns.NamespaceManager - serverTLSConfig *tls.Config - clusterTLSConfig *tls.Config - wg waitgroup.WaitGroup + listener net.Listener + logger *zap.Logger + nsmgr *mgrns.NamespaceManager + frontendTLSConfig *tls.Config + backendTLSConfig *tls.Config + wg waitgroup.WaitGroup mu serverState } // NewSQLServer creates a new SQLServer. -func NewSQLServer(logger *zap.Logger, workdir string, cfg config.ProxyServer, scfg config.Security, nsmgr *mgrns.NamespaceManager) (*SQLServer, error) { +func NewSQLServer(logger *zap.Logger, cfg config.ProxyServer, scfg config.Security, nsmgr *mgrns.NamespaceManager) (*SQLServer, error) { var err error s := &SQLServer{ @@ -65,10 +65,10 @@ func NewSQLServer(logger *zap.Logger, workdir string, cfg config.ProxyServer, sc }, } - if s.serverTLSConfig, err = security.CreateServerTLSConfig(logger, scfg.Server.CA, scfg.Server.Key, scfg.Server.Cert, scfg.RSAKeySize, workdir); err != nil { + if s.frontendTLSConfig, err = security.BuildServerTLSConfig(logger, scfg.ServerTLS); err != nil { return nil, err } - if s.clusterTLSConfig, err = security.CreateClientTLSConfig(scfg.Cluster.CA, scfg.Cluster.Key, scfg.Cluster.Cert); err != nil { + if s.backendTLSConfig, err = security.BuildClientTLSConfig(logger, scfg.SQLTLS); err != nil { return nil, err } @@ -125,7 +125,7 @@ func (s *SQLServer) onConn(ctx context.Context, conn net.Conn) { connID := s.mu.connID s.mu.connID++ logger := s.logger.With(zap.Uint64("connID", connID)) - clientConn := client.NewClientConnection(logger.Named("cliconn"), conn, s.serverTLSConfig, s.clusterTLSConfig, s.nsmgr, backend.NewBackendConnManager(logger.Named("bemgr"), connID)) + clientConn := client.NewClientConnection(logger.Named("cliconn"), conn, s.frontendTLSConfig, s.backendTLSConfig, s.nsmgr, backend.NewBackendConnManager(logger.Named("bemgr"), connID)) s.mu.clients[connID] = clientConn s.mu.Unlock() diff --git a/pkg/server/server.go b/pkg/server/server.go index d8f11aea..eaaf5bc2 100644 --- a/pkg/server/server.go +++ b/pkg/server/server.go @@ -21,13 +21,13 @@ import ( "net/url" "path/filepath" "strconv" - "strings" "time" ginzap "github.com/gin-contrib/zap" "github.com/gin-gonic/gin" "github.com/pingcap/TiProxy/lib/config" "github.com/pingcap/TiProxy/lib/util/errors" + "github.com/pingcap/TiProxy/lib/util/security" "github.com/pingcap/TiProxy/lib/util/waitgroup" mgrcfg "github.com/pingcap/TiProxy/pkg/manager/config" mgrns "github.com/pingcap/TiProxy/pkg/manager/namespace" @@ -45,21 +45,34 @@ type Server struct { // managers ConfigManager *mgrcfg.ConfigManager NamespaceManager *mgrns.NamespaceManager - ObserverClient *clientv3.Client MetricsManager *metrics.MetricsManager - + ObserverClient *clientv3.Client + // HTTP client + Http *http.Client // HTTP/GRPC services Etcd *embed.Etcd - // L7 proxy Proxy *proxy.SQLServer } func NewServer(ctx context.Context, cfg *config.Config, logger *zap.Logger, pubAddr string) (srv *Server, err error) { + { + tlogger := logger.Named("tls") + // auto generate CA for serverTLS will break + if uerr := security.AutoTLS(tlogger, &cfg.Security.ServerTLS, false, cfg.Workdir, "server", cfg.Security.RSAKeySize); uerr != nil { + err = errors.WithStack(uerr) + return + } + if uerr := security.AutoTLS(tlogger, &cfg.Security.PeerTLS, true, cfg.Workdir, "peer", cfg.Security.RSAKeySize); uerr != nil { + err = errors.WithStack(uerr) + return + } + } + srv = &Server{ ConfigManager: mgrcfg.NewConfigManager(), - NamespaceManager: mgrns.NewNamespaceManager(), MetricsManager: metrics.NewMetricsManager(), + NamespaceManager: mgrns.NewNamespaceManager(), } ready := atomic.NewBool(false) @@ -95,7 +108,7 @@ func NewServer(ctx context.Context, cfg *config.Config, logger *zap.Logger, pubA // 2. pass down '*Server' struct such that the underlying relies on the pointer only. But it does not work well for golang. To avoid cyclic imports between 'api' and `server` packages, two packages needs to be merged. That is basically what happened to TiDB '*Session'. api.Register(engine.Group("/api"), ready, cfg.API, logger.Named("api"), srv.NamespaceManager, srv.ConfigManager) - srv.Etcd, err = buildEtcd(ctx, cfg, logger, pubAddr, engine) + srv.Etcd, err = buildEtcd(ctx, cfg, logger.Named("etcd"), pubAddr, engine) if err != nil { err = errors.WithStack(err) return @@ -110,13 +123,23 @@ func NewServer(ctx context.Context, cfg *config.Config, logger *zap.Logger, pubA } } - // setup config manager + // general cluster HTTP client { - addrs := make([]string, len(srv.Etcd.Clients)) - for i := range addrs { - addrs[i] = srv.Etcd.Clients[i].Addr().String() + clientTLS, uerr := security.BuildClientTLSConfig(logger.Named("http"), cfg.Security.ClusterTLS) + if uerr != nil { + err = errors.WithStack(err) + return + } + srv.Http = &http.Client{ + Transport: &http.Transport{ + TLSClientConfig: clientTLS, + }, } - err = srv.ConfigManager.Init(ctx, addrs, cfg.Advance, logger.Named("config")) + } + + // setup config manager + { + err = srv.ConfigManager.Init(ctx, srv.Etcd.Server.KV(), cfg.Advance, logger.Named("config")) if err != nil { err = errors.WithStack(err) return @@ -144,7 +167,7 @@ func NewServer(ctx context.Context, cfg *config.Config, logger *zap.Logger, pubA // setup namespace manager { - srv.ObserverClient, err = router.InitEtcdClient(cfg) + srv.ObserverClient, err = router.InitEtcdClient(logger.Named("pd"), cfg) if err != nil { err = errors.WithStack(err) return @@ -157,7 +180,7 @@ func NewServer(ctx context.Context, cfg *config.Config, logger *zap.Logger, pubA return } - err = srv.NamespaceManager.Init(logger.Named("nsmgr"), nss, srv.ObserverClient) + err = srv.NamespaceManager.Init(logger.Named("nsmgr"), nss, srv.ObserverClient, srv.Http) if err != nil { err = errors.WithStack(err) return @@ -166,7 +189,7 @@ func NewServer(ctx context.Context, cfg *config.Config, logger *zap.Logger, pubA // setup proxy server { - srv.Proxy, err = proxy.NewSQLServer(logger.Named("proxy"), cfg.Workdir, cfg.Proxy, cfg.Security, srv.NamespaceManager) + srv.Proxy, err = proxy.NewSQLServer(logger.Named("proxy"), cfg.Proxy, cfg.Security, srv.NamespaceManager) if err != nil { err = errors.WithStack(err) return @@ -226,15 +249,19 @@ func (s *Server) Close() error { func buildEtcd(ctx context.Context, cfg *config.Config, logger *zap.Logger, pubAddr string, engine *gin.Engine) (srv *embed.Etcd, err error) { etcd_cfg := embed.NewConfig() - apiAddrStr := cfg.API.Addr - if !strings.HasPrefix(apiAddrStr, "http://") { - apiAddrStr = fmt.Sprintf("http://%s", apiAddrStr) - } - apiAddr, uerr := url.Parse(apiAddrStr) - if uerr != nil { - err = errors.WithStack(uerr) + if etcd_cfg.ClientTLSInfo, etcd_cfg.PeerTLSInfo, err = security.BuildEtcdTLSConfig(logger, cfg.Security.ServerTLS, cfg.Security.PeerTLS); err != nil { return } + + apiAddr, err := url.Parse(fmt.Sprintf("http://%s", cfg.API.Addr)) + if err != nil { + return nil, err + } + if etcd_cfg.ClientTLSInfo.Empty() { + apiAddr.Scheme = "http" + } else { + apiAddr.Scheme = "https" + } etcd_cfg.LCUrls = []url.URL{*apiAddr} apiAddrAdvertise := *apiAddr apiAddrAdvertise.Host = fmt.Sprintf("%s:%s", pubAddr, apiAddrAdvertise.Port()) @@ -242,24 +269,29 @@ func buildEtcd(ctx context.Context, cfg *config.Config, logger *zap.Logger, pubA peerPort := cfg.Advance.PeerPort if peerPort == "" { - peerPortNum, uerr := strconv.Atoi(apiAddr.Port()) - if uerr != nil { - err = errors.WithStack(uerr) - return + peerPortNum, err := strconv.Atoi(apiAddr.Port()) + if err != nil { + return nil, err } peerPort = strconv.Itoa(peerPortNum + 1) } peerAddr := *apiAddr + if etcd_cfg.PeerTLSInfo.Empty() { + peerAddr.Scheme = "http" + } else { + peerAddr.Scheme = "https" + } peerAddr.Host = fmt.Sprintf("%s:%s", peerAddr.Hostname(), peerPort) etcd_cfg.LPUrls = []url.URL{peerAddr} - peerAddrAdvertise := *apiAddr + peerAddrAdvertise := peerAddr peerAddrAdvertise.Host = fmt.Sprintf("%s:%s", pubAddr, peerPort) etcd_cfg.APUrls = []url.URL{peerAddrAdvertise} etcd_cfg.Name = "proxy-" + fmt.Sprint(time.Now().UnixMicro()) etcd_cfg.InitialCluster = etcd_cfg.InitialClusterFromName(etcd_cfg.Name) etcd_cfg.Dir = filepath.Join(cfg.Workdir, "etcd") - etcd_cfg.ZapLoggerBuilder = embed.NewZapLoggerBuilder(logger.Named("etcd")) + etcd_cfg.ZapLoggerBuilder = embed.NewZapLoggerBuilder(logger) + etcd_cfg.UserHandlers = map[string]http.Handler{ "/api/": engine, }