From eb5b4b9c9e79cce452c3808d3bc42be69fa07e70 Mon Sep 17 00:00:00 2001 From: djshow832 Date: Thu, 7 Sep 2023 21:09:44 +0800 Subject: [PATCH] backend, net: Remind the user when TiDB sets wrong proxy-protocol (#362) --- pkg/proxy/backend/authenticator.go | 11 +++++- pkg/proxy/backend/authenticator_test.go | 49 +++++++++++++++++++++++++ pkg/proxy/backend/error.go | 9 +++-- pkg/proxy/backend/mock_backend_test.go | 7 +++- pkg/proxy/client/client_conn.go | 4 +- pkg/proxy/net/packetio.go | 4 +- pkg/proxy/proxy.go | 3 +- 7 files changed, 76 insertions(+), 11 deletions(-) diff --git a/pkg/proxy/backend/authenticator.go b/pkg/proxy/backend/authenticator.go index 2689e649..cdec1704 100644 --- a/pkg/proxy/backend/authenticator.go +++ b/pkg/proxy/backend/authenticator.go @@ -210,12 +210,19 @@ func (auth *Authenticator) handshakeFirstTime(logger *zap.Logger, cctx ConnConte // forward other packets pluginName := "" + pktIdx := 0 loop: for { serverPkt, err := forwardMsg(backendIO, clientIO) if err != nil { + // tiproxy pp enabled, tidb pp disabled, tls disabled => invalid sequence + // tiproxy pp disabled, tidb pp enabled, tls disabled => invalid sequence + if pktIdx == 0 && errors.Is(err, pnet.ErrInvalidSequence) { + return pnet.WrapUserError(err, checkPPV2ErrMsg) + } return err } + pktIdx++ switch serverPkt[0] { case pnet.OKHeader.Byte(): return nil @@ -334,7 +341,9 @@ func (auth *Authenticator) writeAuthHandshake( tcfg.ServerName = host } if err := backendIO.ClientTLSHandshake(tcfg); err != nil { - return err + // tiproxy pp enabled, tidb pp disabled, tls enabled => tls handshake encounters unrecognized packet + // tiproxy pp disabled, tidb pp enabled, tls enabled => tls handshake encounters unrecognized packet + return pnet.WrapUserError(err, checkPPV2ErrMsg) } } else { resp.Capability &= ^pnet.ClientSSL diff --git a/pkg/proxy/backend/authenticator_test.go b/pkg/proxy/backend/authenticator_test.go index c5da07c0..9fcbc89c 100644 --- a/pkg/proxy/backend/authenticator_test.go +++ b/pkg/proxy/backend/authenticator_test.go @@ -338,3 +338,52 @@ func TestRequireBackendTLS(t *testing.T) { clean() } } + +func TestProxyProtocol(t *testing.T) { + cfgs := [][]cfgOverrider{ + { + func(cfg *testConfig) { + cfg.proxyConfig.bcConfig.ProxyProtocol = true + }, + func(cfg *testConfig) { + cfg.proxyConfig.bcConfig.ProxyProtocol = false + }, + }, + { + func(cfg *testConfig) { + cfg.backendConfig.proxyProtocol = true + }, + func(cfg *testConfig) { + cfg.backendConfig.proxyProtocol = false + }, + }, + { + func(cfg *testConfig) { + cfg.proxyConfig.bcConfig.RequireBackendTLS = true + cfg.backendConfig.capability |= pnet.ClientSSL + }, + func(cfg *testConfig) { + cfg.proxyConfig.bcConfig.RequireBackendTLS = false + cfg.backendConfig.capability &= ^pnet.ClientSSL + }, + }, + } + + tc := newTCPConnSuite(t) + cfgOverriders := getCfgCombinations(cfgs) + for _, cfgs := range cfgOverriders { + ts, clean := newTestSuite(t, tc, cfgs...) + ts.authenticateFirstTime(t, func(t *testing.T, ts *testSuite) { + // TiDB proxy-protocol can be set unfallbackable, but TiProxy proxy-protocol is always fallbackable. + // So when backend enables proxy-protocol and proxy disables it, it still works well. + if ts.mp.bcConfig.ProxyProtocol && !ts.mb.proxyProtocol { + var userError *pnet.UserError + require.True(t, errors.As(ts.mp.err, &userError)) + require.Equal(t, checkPPV2ErrMsg, userError.UserMsg()) + } else { + require.NoError(t, ts.mp.err) + } + }) + clean() + } +} diff --git a/pkg/proxy/backend/error.go b/pkg/proxy/backend/error.go index b2eb3295..4097126a 100644 --- a/pkg/proxy/backend/error.go +++ b/pkg/proxy/backend/error.go @@ -8,12 +8,13 @@ import ( ) const ( - connectErrMsg = "No available TiDB instances, please check TiDB cluster" + connectErrMsg = "No available TiDB instances, please make sure TiDB is available" parsePktErrMsg = "TiProxy fails to parse the packet, please contact PingCAP" - handshakeErrMsg = "TiProxy fails to connect to TiDB, please check network" + handshakeErrMsg = "TiProxy fails to connect to TiDB, please make sure TiDB is available" capabilityErrMsg = "Verify TiDB capability failed, please upgrade TiDB" - requireProxyTLSErrMsg = "Require TLS config on TiProxy when require-backend-tls=true" - requireTiDBTLSErrMsg = "Require TLS config on TiDB when require-backend-tls=true" + requireProxyTLSErrMsg = "Require TLS enabled on TiProxy when require-backend-tls=true" + requireTiDBTLSErrMsg = "Require TLS enabled on TiDB when require-backend-tls=true" + checkPPV2ErrMsg = "TiProxy fails to connect to TiDB, please make sure TiDB proxy-protocol is set correctly. If this error still exists, please contact PingCAP" ) var ( diff --git a/pkg/proxy/backend/mock_backend_test.go b/pkg/proxy/backend/mock_backend_test.go index 9a8ec1bc..d272db49 100644 --- a/pkg/proxy/backend/mock_backend_test.go +++ b/pkg/proxy/backend/mock_backend_test.go @@ -24,6 +24,7 @@ type backendConfig struct { stmtNum int capability pnet.Capability status uint16 + proxyProtocol bool authSucceed bool abnormalExit bool } @@ -61,6 +62,9 @@ func (mb *mockBackend) authenticate(packetIO *pnet.PacketIO) error { if mb.abnormalExit { return packetIO.Close() } + if mb.proxyProtocol { + packetIO.ApplyOpts(pnet.WithProxy) + } var err error // write initial handshake if err = packetIO.WriteInitialHandshake(mb.capability, mb.salt, mb.authPlugin, pnet.ServerVersion, 100); err != nil { @@ -68,8 +72,9 @@ func (mb *mockBackend) authenticate(packetIO *pnet.PacketIO) error { } // read the response var clientPkt []byte + // Unlike TiProxy, TiDB always sends an error to the client even if EOF. if clientPkt, err = packetIO.ReadPacket(); err != nil { - return err + return packetIO.WriteErrPacket(mysql.NewError(mysql.ER_UNKNOWN_ERROR, err.Error())) } // upgrade to TLS capability := binary.LittleEndian.Uint16(clientPkt[:2]) diff --git a/pkg/proxy/client/client_conn.go b/pkg/proxy/client/client_conn.go index bf7d54b2..1c5db433 100644 --- a/pkg/proxy/client/client_conn.go +++ b/pkg/proxy/client/client_conn.go @@ -32,7 +32,7 @@ func NewClientConnection(logger *zap.Logger, conn net.Conn, frontendTLSConfig *t } pkt := pnet.NewPacketIO(conn, logger, opts...) return &ClientConnection{ - logger: logger.With(zap.Bool("proxy-protocol", bcConfig.ProxyProtocol)), + logger: logger, frontendTLSConfig: frontendTLSConfig, backendTLSConfig: backendTLSConfig, pkt: pkt, @@ -58,7 +58,7 @@ clean: switch src { case backend.SrcClientQuit, backend.SrcClientErr, backend.SrcProxyQuit: default: - cc.logger.Info(msg, zap.String("backend_addr", cc.connMgr.ServerAddr()), zap.Stringer("quit source", src), zap.Error(err)) + cc.logger.Warn(msg, zap.String("backend_addr", cc.connMgr.ServerAddr()), zap.Stringer("quit source", src), zap.Error(err)) } } diff --git a/pkg/proxy/net/packetio.go b/pkg/proxy/net/packetio.go index 5769893c..5b151698 100644 --- a/pkg/proxy/net/packetio.go +++ b/pkg/proxy/net/packetio.go @@ -43,7 +43,7 @@ import ( ) var ( - errInvalidSequence = dbterror.ClassServer.NewStd(errno.ErrInvalidSequence) + ErrInvalidSequence = dbterror.ClassServer.NewStd(errno.ErrInvalidSequence) ) const ( @@ -169,7 +169,7 @@ func (p *PacketIO) readOnePacket() ([]byte, bool, error) { sequence := header[3] if sequence != p.sequence { - return nil, false, errInvalidSequence.GenWithStack("invalid sequence %d != %d", sequence, p.sequence) + return nil, false, ErrInvalidSequence.GenWithStack("invalid sequence %d != %d", sequence, p.sequence) } p.sequence++ length := int(uint32(header[0]) | uint32(header[1])<<8 | uint32(header[2])<<16) diff --git a/pkg/proxy/proxy.go b/pkg/proxy/proxy.go index 41e7308a..0a566384 100644 --- a/pkg/proxy/proxy.go +++ b/pkg/proxy/proxy.go @@ -145,7 +145,8 @@ func (s *SQLServer) onConn(ctx context.Context, conn net.Conn) { connID := s.mu.connID s.mu.connID++ - logger := s.logger.With(zap.Uint64("connID", connID), zap.String("client_addr", conn.RemoteAddr().String())) + logger := s.logger.With(zap.Uint64("connID", connID), zap.String("client_addr", conn.RemoteAddr().String()), + zap.Bool("proxy-protocol", s.mu.proxyProtocol)) clientConn := client.NewClientConnection(logger.Named("conn"), conn, s.certMgr.ServerTLS(), s.certMgr.SQLTLS(), s.hsHandler, connID, &backend.BCConfig{ ProxyProtocol: s.mu.proxyProtocol,