Skip to content

Commit

Permalink
proxy: enable non-TLS connection to TiDB (#117)
Browse files Browse the repository at this point in the history
  • Loading branch information
xhebox authored Oct 19, 2022
1 parent 9917e02 commit 9542e0c
Show file tree
Hide file tree
Showing 17 changed files with 147 additions and 95 deletions.
20 changes: 12 additions & 8 deletions cmd/tiproxy/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ import (
"github.com/pingcap/TiProxy/lib/util/cmd"
"github.com/pingcap/TiProxy/lib/util/errors"
"github.com/pingcap/TiProxy/lib/util/waitgroup"
"github.com/pingcap/TiProxy/pkg/sctx"
"github.com/pingcap/TiProxy/pkg/server"
"github.com/spf13/cobra"
)
Expand Down Expand Up @@ -60,16 +61,19 @@ func main() {
cfg.Log.Level = *logLevel
}

cfg.Cluster = config.Cluster{
PubAddr: *pubAddr,
ClusterName: *clusterName,
NodeName: *nodeName,
BootstrapDurl: *bootstrapDiscoveryUrl,
BootstrapDdns: *bootstrapDiscoveryDNS,
BootstrapClusters: *bootstrapClusters,
sctx := &sctx.Context{
Config: cfg,
Cluster: sctx.Cluster{
PubAddr: *pubAddr,
ClusterName: *clusterName,
NodeName: *nodeName,
BootstrapDurl: *bootstrapDiscoveryUrl,
BootstrapDdns: *bootstrapDiscoveryDNS,
BootstrapClusters: *bootstrapClusters,
},
}

srv, err := server.NewServer(cmd.Context(), cfg)
srv, err := server.NewServer(cmd.Context(), sctx)
if err != nil {
return errors.Wrapf(err, "fail to create server")
}
Expand Down
1 change: 1 addition & 0 deletions conf/proxy.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ proxy:
tcp-keep-alive: true
max-connections: 1000
pd-addrs: "127.0.0.1:2379"
# require-backend-tls: true
# proxy-protocol: "v2"
metrics:
api:
Expand Down
4 changes: 2 additions & 2 deletions docker/Dockerfile
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
FROM alpine:latest
FROM alpine:edge

EXPOSE 3080
EXPOSE 3081
EXPOSE 6000

ADD . https://raw.githubusercontent.com/njhallett/apk-fastest-mirror/main/apk-fastest-mirror.sh /proxy
ADD . https://raw.githubusercontent.com/njhallett/apk-fastest-mirror/c4ca44caef3385d830fea34df2dbc2ba4a17e021/apk-fastest-mirror.sh /proxy
RUN sh ./proxy/apk-fastest-mirror.sh -t 50 && apk add --no-cache --progress git make go
ARG BUILDFLAGS
ARG GOPROXY
Expand Down
12 changes: 2 additions & 10 deletions lib/config/proxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,16 +35,6 @@ type Config struct {
Security Security `yaml:"security,omitempty" toml:"security,omitempty" json:"security,omitempty"`
Metrics Metrics `yaml:"metrics,omitempty" toml:"metrics,omitempty" json:"metrics,omitempty"`
Log Log `yaml:"log,omitempty" toml:"log,omitempty" json:"log,omitempty"`
Cluster Cluster `yaml:"-" toml:"-" json:"-"`
}

type Cluster struct {
PubAddr string
ClusterName string
NodeName string
BootstrapDurl string
BootstrapDdns string
BootstrapClusters []string
}

type Metrics struct {
Expand All @@ -61,6 +51,7 @@ type ProxyServerOnline struct {
type ProxyServer struct {
Addr string `yaml:"addr,omitempty" toml:"addr,omitempty" json:"addr,omitempty"`
PDAddrs string `yaml:"pd-addrs,omitempty" toml:"pd-addrs,omitempty" json:"pd-addrs,omitempty"`
RequireBackendTLS bool `yaml:"require-backend-tls,omitempty" toml:"require-backend-tls,omitempty" json:"require-backend-tls,omitempty"`
ProxyServerOnline `yaml:",inline" toml:",inline" json:",inline"`
}

Expand Down Expand Up @@ -121,6 +112,7 @@ type Security struct {
func NewConfig(data []byte) (*Config, error) {
var cfg Config
cfg.Advance.IgnoreWrongNamespace = true
cfg.Proxy.RequireBackendTLS = true
if err := yaml.Unmarshal(data, &cfg); err != nil {
return nil, err
}
Expand Down
5 changes: 3 additions & 2 deletions lib/config/proxy_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,9 @@ var testProxyConfig = Config{
WatchInterval: "30m",
},
Proxy: ProxyServer{
Addr: "0.0.0.0:4000",
PDAddrs: "127.0.0.1:4089",
Addr: "0.0.0.0:4000",
PDAddrs: "127.0.0.1:4089",
RequireBackendTLS: true,
ProxyServerOnline: ProxyServerOnline{
MaxConnections: 1,
TCPKeepAlive: true,
Expand Down
8 changes: 6 additions & 2 deletions lib/util/errors/werror.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,11 @@ func (e *WError) Error() string {
}

func (e *WError) Is(s error) bool {
return errors.Is(e.cerr, s)
r := errors.Is(e.cerr, s)
if r {
return r
}
return errors.Is(e.uerr, s)
}

func (e *WError) Unwrap() error {
Expand All @@ -60,7 +64,7 @@ func (e *WError) Unwrap() error {
// Wrap is used to wrapping unknown errors. A typical example is that:
// 1. have a function `ReadMyConfig()`
// 2. it got errors returned from external libraries
// 3. you want to wrap these errors, expect `Unwrap(err) == ErrExternalErrors && Is(err, ErrReadMyConfig)`.
// 3. you want to wrap these errors, expect `Unwrap(err) == ErrExternalErrors && Is(err, ErrReadMyConfig) && Is(err, ErrExternalErrors)`.
// 4. then you are finding `err := Wrap(ErrReadMyConfig, ErrExternalErrors)`
// Note that wrap nil error will get nil error.
func Wrap(cerr error, uerr error) error {
Expand Down
1 change: 1 addition & 0 deletions lib/util/errors/werror_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ func TestWrapf(t *testing.T) {
e2 := serr.New("dd")
e := serr.Wrapf(e1, "%w: 4", e2)
require.ErrorIsf(t, e, e1, "equal to the external error")
require.ErrorIsf(t, e, e2, "equal to the underlying error, too")
require.ErrorAsf(t, e, &e2, "unwrapping to the internal error")

require.Nil(t, serr.Wrapf(nil, ""), "wrap nil got nil")
Expand Down
62 changes: 37 additions & 25 deletions pkg/proxy/backend/authenticator.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,14 +32,14 @@ var (
)

const requiredFrontendCaps = pnet.ClientProtocol41
const requiredBackendCaps = pnet.ClientDeprecateEOF | pnet.ClientSSL
const defRequiredBackendCaps = pnet.ClientDeprecateEOF

// Other server capabilities are not supported. ClientDeprecateEOF is supported but TiDB 6.2.0 doesn't support it now.
const supportedServerCapabilities = pnet.ClientLongPassword | pnet.ClientFoundRows | pnet.ClientConnectWithDB |
pnet.ClientODBC | pnet.ClientLocalFiles | pnet.ClientInteractive | pnet.ClientLongFlag |
pnet.ClientODBC | pnet.ClientLocalFiles | pnet.ClientInteractive | pnet.ClientLongFlag | pnet.ClientSSL |
pnet.ClientTransactions | pnet.ClientReserved | pnet.ClientSecureConnection | pnet.ClientMultiStatements |
pnet.ClientMultiResults | pnet.ClientPluginAuth | pnet.ClientConnectAttrs | pnet.ClientPluginAuthLenencClientData |
requiredFrontendCaps | requiredBackendCaps
requiredFrontendCaps | defRequiredBackendCaps

// Authenticator handshakes with the client and the backend.
type Authenticator struct {
Expand All @@ -52,6 +52,7 @@ type Authenticator struct {
capability uint32 // client capability
collation uint8
proxyProtocol bool
requireBackendTLS bool
}

func (auth *Authenticator) String() string {
Expand Down Expand Up @@ -140,11 +141,16 @@ func (auth *Authenticator) handshakeFirstTime(logger *zap.Logger, clientIO, back
return err
}
backendCapability := pnet.Capability(backendCapabilityU)
requiredBackendCaps := defRequiredBackendCaps
if auth.requireBackendTLS {
requiredBackendCaps |= pnet.ClientSSL
}

if commonCaps := backendCapability & requiredBackendCaps; commonCaps != requiredBackendCaps {
// The error cannot be sent to the client because the client only expects an initial handshake packet.
// The only way is to log it and disconnect.
logger.Error("require backend capabilities", zap.Stringer("common", commonCaps), zap.Stringer("required", requiredBackendCaps))
return errors.Wrapf(ErrCapabilityNegotiation, "require %s from backend", requiredBackendCaps&^commonCaps)
logger.Error("require backend capabilities", zap.Stringer("common", commonCaps), zap.Stringer("required", requiredBackendCaps^commonCaps))
return errors.Wrapf(ErrCapabilityNegotiation, "require %s from backend", requiredBackendCaps^commonCaps)
}
if common := proxyCapability & backendCapability; (proxyCapability^common)&^pnet.ClientSSL != 0 {
// TODO: need to do negotiation with backend
Expand All @@ -158,28 +164,34 @@ func (auth *Authenticator) handshakeFirstTime(logger *zap.Logger, clientIO, back
logger.Info("backend does not support capabilities from proxy", zap.Stringer("common", common), zap.Stringer("proxy", proxyCapability^common), zap.Stringer("backend", backendCapability^common))
}

resp.Capability = auth.capability | mysql.ClientSSL
// Send an unknown auth plugin so that the backend will request the auth data again.
resp.AuthPlugin = "auth_unknown_plugin"
pkt = pnet.MakeHandshakeResponse(resp)
// write SSL Packet
if err := backendIO.WritePacket(pkt[:32], true); err != nil {
return err
}
auth.backendTLSConfig = backendTLSConfig.Clone()
addr := backendIO.RemoteAddr().String()
if auth.serverAddr != "" {
// NOTE: should use DNS name as much as possible
// Usually certs are signed with domain instead of IP addrs
// And `RemoteAddr()` will return IP addr
addr = auth.serverAddr
}
host, _, err := net.SplitHostPort(addr)
if err == nil {
auth.backendTLSConfig.ServerName = host
}
if err = backendIO.ClientTLSHandshake(auth.backendTLSConfig); err != nil {
return err
resp.Capability = auth.capability

if backendCapability&pnet.ClientSSL != 0 {
resp.Capability |= mysql.ClientSSL
pkt = pnet.MakeHandshakeResponse(resp)
// write SSL Packet
if err := backendIO.WritePacket(pkt[:32], true); err != nil {
return err
}
auth.backendTLSConfig = backendTLSConfig.Clone()
addr := backendIO.RemoteAddr().String()
if auth.serverAddr != "" {
// NOTE: should use DNS name as much as possible
// Usually certs are signed with domain instead of IP addrs
// And `RemoteAddr()` will return IP addr
addr = auth.serverAddr
}
host, _, err := net.SplitHostPort(addr)
if err == nil {
auth.backendTLSConfig.ServerName = host
}
if err = backendIO.ClientTLSHandshake(auth.backendTLSConfig); err != nil {
return err
}
} else {
pkt = pnet.MakeHandshakeResponse(resp)
}

// forward client handshake resp
Expand Down
2 changes: 1 addition & 1 deletion pkg/proxy/backend/authenticator_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ func TestUnsupportedCapability(t *testing.T) {
for _, cfgs := range cfgOverriders {
ts, clean := newTestSuite(t, tc, cfgs...)
ts.authenticateFirstTime(t, func(t *testing.T, _ *testSuite) {
if ts.mb.backendConfig.capability&requiredBackendCaps.Uint32() != requiredBackendCaps.Uint32() {
if ts.mb.backendConfig.capability&defRequiredBackendCaps.Uint32() != defRequiredBackendCaps.Uint32() {
require.ErrorIs(t, ts.mp.err, ErrCapabilityNegotiation)
} else if ts.mc.clientConfig.capability&requiredFrontendCaps.Uint32() != requiredFrontendCaps.Uint32() {
require.ErrorIs(t, ts.mp.err, ErrCapabilityNegotiation)
Expand Down
4 changes: 2 additions & 2 deletions pkg/proxy/backend/backend_conn_mgr.go
Original file line number Diff line number Diff line change
Expand Up @@ -85,12 +85,12 @@ type BackendConnManager struct {
}

// NewBackendConnManager creates a BackendConnManager.
func NewBackendConnManager(logger *zap.Logger, connectionID uint64, proxyProtocol bool) *BackendConnManager {
func NewBackendConnManager(logger *zap.Logger, connectionID uint64, proxyProtocol, requireBackendTLS bool) *BackendConnManager {
return &BackendConnManager{
logger: logger,
connectionID: connectionID,
cmdProcessor: NewCmdProcessor(),
authenticator: &Authenticator{supportedServerCapabilities: supportedServerCapabilities, proxyProtocol: proxyProtocol},
authenticator: &Authenticator{supportedServerCapabilities: supportedServerCapabilities, proxyProtocol: proxyProtocol, requireBackendTLS: requireBackendTLS},
signalReceived: make(chan struct{}, 1),
redirectResCh: make(chan *redirectResult, 1),
}
Expand Down
2 changes: 1 addition & 1 deletion pkg/proxy/backend/mock_proxy_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ func newMockProxy(t *testing.T, cfg *proxyConfig) *mockProxy {
mp := &mockProxy{
proxyConfig: cfg,
logger: logger.CreateLoggerForTest(t).Named("mockProxy"),
BackendConnManager: NewBackendConnManager(logger.CreateLoggerForTest(t), 0, false),
BackendConnManager: NewBackendConnManager(logger.CreateLoggerForTest(t), 0, false, false),
}
mp.cmdProcessor.capability = cfg.capability
return mp
Expand Down
6 changes: 3 additions & 3 deletions pkg/proxy/client/client_conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,10 +42,10 @@ type ClientConnection struct {
connMgr *backend.BackendConnManager
}

func NewClientConnection(logger *zap.Logger, conn net.Conn, frontendTLSConfig *tls.Config, backendTLSConfig *tls.Config, nsmgr *namespace.NamespaceManager, connID uint64, proxyProtocol bool) *ClientConnection {
bemgr := backend.NewBackendConnManager(logger.Named("be"), connID, proxyProtocol)
func NewClientConnection(logger *zap.Logger, conn net.Conn, frontendTLSConfig *tls.Config, backendTLSConfig *tls.Config, nsmgr *namespace.NamespaceManager, connID uint64, proxyProtocol, requireBackendTLS bool) *ClientConnection {
bemgr := backend.NewBackendConnManager(logger.Named("be"), connID, proxyProtocol, requireBackendTLS)
opts := make([]pnet.PacketIOption, 0, 2)
opts = append(opts, pnet.WithClient)
opts = append(opts, pnet.WithWrapError(ErrClientConn))
if proxyProtocol {
opts = append(opts, pnet.WithProxy)
}
Expand Down
4 changes: 1 addition & 3 deletions pkg/proxy/net/packetio.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,9 +50,7 @@ import (

var (
errInvalidSequence = dbterror.ClassServer.NewStd(errno.ErrInvalidSequence)
ErrClientConn = errors.New("this is an error from client")

proxyV2Magic = []byte{0xD, 0xA, 0xD, 0xA, 0x0, 0xD, 0xA, 0x51, 0x55, 0x49, 0x54, 0xA}
proxyV2Magic = []byte{0xD, 0xA, 0xD, 0xA, 0x0, 0xD, 0xA, 0x51, 0x55, 0x49, 0x54, 0xA}
)

const (
Expand Down
6 changes: 4 additions & 2 deletions pkg/proxy/net/packetio_options.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@ func WithProxy(pi *PacketIO) {
pi.proxyInited = atomic.NewBool(true)
}

func WithClient(pi *PacketIO) {
pi.wrap = ErrClientConn
func WithWrapError(err error) func(pi *PacketIO) {
return func(pi *PacketIO) {
pi.wrap = err
}
}
20 changes: 11 additions & 9 deletions pkg/proxy/proxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,11 +40,12 @@ type serverState struct {
}

type SQLServer struct {
listener net.Listener
logger *zap.Logger
certMgr *cert.CertManager
nsmgr *mgrns.NamespaceManager
wg waitgroup.WaitGroup
listener net.Listener
logger *zap.Logger
certMgr *cert.CertManager
nsmgr *mgrns.NamespaceManager
requireBackendTLS bool
wg waitgroup.WaitGroup

mu serverState
}
Expand All @@ -54,9 +55,10 @@ func NewSQLServer(logger *zap.Logger, cfg config.ProxyServer, certMgr *cert.Cert
var err error

s := &SQLServer{
logger: logger,
certMgr: certMgr,
nsmgr: nsmgr,
logger: logger,
certMgr: certMgr,
nsmgr: nsmgr,
requireBackendTLS: cfg.RequireBackendTLS,
mu: serverState{
connID: 0,
clients: make(map[uint64]*client.ClientConnection),
Expand Down Expand Up @@ -122,7 +124,7 @@ func (s *SQLServer) onConn(ctx context.Context, conn net.Conn) {
connID := s.mu.connID
s.mu.connID++
logger := s.logger.With(zap.Uint64("connID", connID), zap.String("remoteAddr", conn.RemoteAddr().String()))
clientConn := client.NewClientConnection(logger.Named("conn"), conn, s.certMgr.ServerTLS(), s.certMgr.SQLTLS(), s.nsmgr, connID, s.mu.proxyProtocol)
clientConn := client.NewClientConnection(logger.Named("conn"), conn, s.certMgr.ServerTLS(), s.certMgr.SQLTLS(), s.nsmgr, connID, s.mu.proxyProtocol, s.requireBackendTLS)
s.mu.clients[connID] = clientConn
s.mu.Unlock()

Expand Down
31 changes: 31 additions & 0 deletions pkg/sctx/context.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
// Copyright 2022 PingCAP, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package sctx

import "github.com/pingcap/TiProxy/lib/config"

type Cluster struct {
PubAddr string
ClusterName string
NodeName string
BootstrapDurl string
BootstrapDdns string
BootstrapClusters []string
}

type Context struct {
Config *config.Config
Cluster Cluster
}
Loading

0 comments on commit 9542e0c

Please sign in to comment.