From 8b4557c1afdc2bbabb544ca555fc12359e18519d Mon Sep 17 00:00:00 2001 From: xhe Date: Mon, 9 Jan 2023 14:17:55 +0800 Subject: [PATCH] backend: use BackendConnMgr as ConnContext (#172) --- pkg/manager/router/backend_selector.go | 4 +- pkg/manager/router/router_static.go | 8 +- pkg/proxy/backend/authenticator.go | 153 ++++++++------------- pkg/proxy/backend/authenticator_test.go | 50 +++---- pkg/proxy/backend/backend_conn.go | 63 --------- pkg/proxy/backend/backend_conn_mgr.go | 146 +++++++++++++------- pkg/proxy/backend/backend_conn_mgr_test.go | 59 ++++---- pkg/proxy/backend/cmd_processor_test.go | 4 +- pkg/proxy/backend/mock_backend_test.go | 14 +- pkg/proxy/backend/mock_client_test.go | 14 +- pkg/proxy/backend/mock_proxy_test.go | 11 +- pkg/proxy/backend/testsuite_test.go | 13 +- pkg/proxy/net/mysql.go | 4 +- pkg/proxy/net/packetio.go | 10 +- pkg/proxy/net/packetio_mysql.go | 2 +- pkg/proxy/net/packetio_options.go | 27 +++- pkg/proxy/net/proxy.go | 2 + 17 files changed, 279 insertions(+), 305 deletions(-) delete mode 100644 pkg/proxy/backend/backend_conn.go diff --git a/pkg/manager/router/backend_selector.go b/pkg/manager/router/backend_selector.go index ca7a4aee..7df65788 100644 --- a/pkg/manager/router/backend_selector.go +++ b/pkg/manager/router/backend_selector.go @@ -26,10 +26,10 @@ func (bs *BackendSelector) Reset() { } func (bs *BackendSelector) Next() string { - if len(bs.cur) > 0 { + bs.cur = bs.routeOnce(bs.excluded) + if bs.cur != "" { bs.excluded = append(bs.excluded, bs.cur) } - bs.cur = bs.routeOnce(bs.excluded) return bs.cur } diff --git a/pkg/manager/router/router_static.go b/pkg/manager/router/router_static.go index afcdfcbe..f7a9db49 100644 --- a/pkg/manager/router/router_static.go +++ b/pkg/manager/router/router_static.go @@ -17,18 +17,18 @@ package router var _ Router = &StaticRouter{} type StaticRouter struct { - addr []string - cnt int + addrs []string + cnt int } func NewStaticRouter(addr []string) *StaticRouter { - return &StaticRouter{addr: addr} + return &StaticRouter{addrs: addr} } func (r *StaticRouter) GetBackendSelector() BackendSelector { return BackendSelector{ routeOnce: func(excluded []string) string { - for _, addr := range r.addr { + for _, addr := range r.addrs { found := false for _, e := range excluded { if e == addr { diff --git a/pkg/proxy/backend/authenticator.go b/pkg/proxy/backend/authenticator.go index cb3e3711..dfc6cc25 100644 --- a/pkg/proxy/backend/authenticator.go +++ b/pkg/proxy/backend/authenticator.go @@ -19,7 +19,6 @@ import ( "encoding/binary" "fmt" "net" - "sync" "time" "github.com/pingcap/TiProxy/lib/util/errors" @@ -47,19 +46,14 @@ const SupportedServerCapabilities = pnet.ClientLongPassword | pnet.ClientFoundRo // Authenticator handshakes with the client and the backend. type Authenticator struct { - backendTLSConfig *tls.Config - ctxmap sync.Map - supportedServerCapabilities pnet.Capability - dbname string // default database name - serverAddr string - clientAddr string - user string - attrs map[string]string - salt []byte - capability uint32 // client capability - collation uint8 - proxyProtocol bool - requireBackendTLS bool + dbname string // default database name + user string + attrs map[string]string + salt []byte + capability uint32 // client capability + collation uint8 + proxyProtocol bool + requireBackendTLS bool } func (auth *Authenticator) String() string { @@ -104,17 +98,16 @@ func (auth *Authenticator) verifyBackendCaps(logger *zap.Logger, backendCapabili type backendIOGetter func(ctx ConnContext, auth *Authenticator, resp *pnet.HandshakeResp, timeout time.Duration) (*pnet.PacketIO, error) -func (auth *Authenticator) handshakeFirstTime(logger *zap.Logger, clientIO *pnet.PacketIO, handshakeHandler HandshakeHandler, +func (auth *Authenticator) handshakeFirstTime(logger *zap.Logger, cctx ConnContext, clientIO *pnet.PacketIO, handshakeHandler HandshakeHandler, getBackendIO backendIOGetter, frontendTLSConfig, backendTLSConfig *tls.Config) error { clientIO.ResetSequence() - auth.clientAddr = clientIO.SourceAddr().String() - proxyCapability := auth.supportedServerCapabilities + proxyCapability := handshakeHandler.GetCapability() if frontendTLSConfig == nil { proxyCapability ^= pnet.ClientSSL } - if err := clientIO.WriteInitialHandshake(proxyCapability.Uint32(), auth.salt, mysql.AuthNativePassword); err != nil { + if err := clientIO.WriteInitialHandshake(proxyCapability, auth.salt, mysql.AuthNativePassword); err != nil { return err } pkt, isSSL, err := clientIO.ReadSSLRequestOrHandshakeResp() @@ -151,17 +144,17 @@ func (auth *Authenticator) handshakeFirstTime(logger *zap.Logger, clientIO *pnet } auth.capability = commonCaps.Uint32() - resp := pnet.ParseHandshakeResponse(pkt) - if err = handshakeHandler.HandleHandshakeResp(auth, resp); err != nil { + clientResp := pnet.ParseHandshakeResponse(pkt) + if err = handshakeHandler.HandleHandshakeResp(cctx, clientResp); err != nil { return err } - auth.user = resp.User - auth.dbname = resp.DB - auth.collation = resp.Collation - auth.attrs = resp.Attrs + auth.user = clientResp.User + auth.dbname = clientResp.DB + auth.collation = clientResp.Collation + auth.attrs = clientResp.Attrs // In case of testing, backendIO is passed manually that we don't want to bother with the routing logic. - backendIO, err := getBackendIO(auth, auth, resp, 5*time.Second) + backendIO, err := getBackendIO(cctx, auth, clientResp, 5*time.Second) if err != nil { return err } @@ -173,11 +166,10 @@ func (auth *Authenticator) handshakeFirstTime(logger *zap.Logger, clientIO *pnet } // read backend initial handshake - _, backendCapabilityU, err := auth.readInitialHandshake(backendIO) + _, backendCapability, err := auth.readInitialHandshake(backendIO) if err != nil { return err } - backendCapability := pnet.Capability(backendCapabilityU) if err := auth.verifyBackendCaps(logger, backendCapability); err != nil { return err @@ -195,38 +187,12 @@ func (auth *Authenticator) handshakeFirstTime(logger *zap.Logger, clientIO *pnet logger.Info("backend does not support capabilities from proxy", zap.Stringer("common", common), zap.Stringer("proxy", proxyCapability^common), zap.Stringer("backend", backendCapability^common)) } - // Send an unknown auth plugin so that the backend will request the auth data again. - resp.AuthPlugin = unknownAuthPlugin - resp.Capability = auth.capability - - if backendCapability&pnet.ClientSSL != 0 && backendTLSConfig != nil { - resp.Capability |= mysql.ClientSSL - pkt = pnet.MakeHandshakeResponse(resp) - // write SSL Packet - if err := backendIO.WritePacket(pkt[:32], true); err != nil { - return err - } - auth.backendTLSConfig = backendTLSConfig.Clone() - addr := backendIO.RemoteAddr().String() - if auth.serverAddr != "" { - // NOTE: should use DNS name as much as possible - // Usually certs are signed with domain instead of IP addrs - // And `RemoteAddr()` will return IP addr - addr = auth.serverAddr - } - host, _, err := net.SplitHostPort(addr) - if err == nil { - auth.backendTLSConfig.ServerName = host - } - if err = backendIO.ClientTLSHandshake(auth.backendTLSConfig); err != nil { - return err - } - } else { - pkt = pnet.MakeHandshakeResponse(resp) - } - // forward client handshake resp - if err := backendIO.WritePacket(pkt, true); err != nil { + if err := auth.writeAuthHandshake( + backendIO, backendTLSConfig, backendCapability, + // send an unknown auth plugin so that the backend will request the auth data again. + unknownAuthPlugin, nil, 0, + ); err != nil { return err } @@ -258,7 +224,7 @@ func forwardMsg(srcIO, destIO *pnet.PacketIO) (data []byte, err error) { return } -func (auth *Authenticator) handshakeSecondTime(logger *zap.Logger, clientIO, backendIO *pnet.PacketIO, sessionToken string) error { +func (auth *Authenticator) handshakeSecondTime(logger *zap.Logger, clientIO, backendIO *pnet.PacketIO, backendTLSConfig *tls.Config, sessionToken string) error { if len(sessionToken) == 0 { return errors.New("session token is empty") } @@ -268,24 +234,26 @@ func (auth *Authenticator) handshakeSecondTime(logger *zap.Logger, clientIO, bac return err } - _, serverCapability, err := auth.readInitialHandshake(backendIO) + _, backendCapability, err := auth.readInitialHandshake(backendIO) if err != nil { return err } - if err := auth.verifyBackendCaps(logger, pnet.Capability(serverCapability)); err != nil { + if err := auth.verifyBackendCaps(logger, pnet.Capability(backendCapability)); err != nil { return err } - tokenBytes := hack.Slice(sessionToken) - if err = auth.writeAuthHandshake(backendIO, tokenBytes, serverCapability&mysql.ClientSSL != 0); err != nil { + if err = auth.writeAuthHandshake( + backendIO, backendTLSConfig, backendCapability, + mysql.AuthTiDBSessionToken, hack.Slice(sessionToken), mysql.ClientPluginAuth, + ); err != nil { return err } return auth.handleSecondAuthResult(backendIO) } -func (auth *Authenticator) readInitialHandshake(backendIO *pnet.PacketIO) (serverPkt []byte, capability uint32, err error) { +func (auth *Authenticator) readInitialHandshake(backendIO *pnet.PacketIO) (serverPkt []byte, capability pnet.Capability, err error) { if serverPkt, err = backendIO.ReadPacket(); err != nil { return } @@ -297,32 +265,49 @@ func (auth *Authenticator) readInitialHandshake(backendIO *pnet.PacketIO) (serve return } -func (auth *Authenticator) writeAuthHandshake(backendIO *pnet.PacketIO, authData []byte, tls bool) error { +func (auth *Authenticator) writeAuthHandshake( + backendIO *pnet.PacketIO, + backendTLSConfig *tls.Config, + backendCapability pnet.Capability, + authPlugin string, + authData []byte, + authCap uint32, +) error { // Always handshake with SSL enabled and enable auth_plugin. resp := &pnet.HandshakeResp{ User: auth.user, DB: auth.dbname, - AuthPlugin: mysql.AuthTiDBSessionToken, Attrs: auth.attrs, - AuthData: authData, - Capability: auth.capability | mysql.ClientSSL | mysql.ClientPluginAuth, Collation: auth.collation, + AuthData: authData, + Capability: auth.capability | authCap, + AuthPlugin: authPlugin, } - data := pnet.MakeHandshakeResponse(resp) - if tls && auth.backendTLSConfig != nil { - // write SSL req - if err := backendIO.WritePacket(data[:32], true); err != nil { + var pkt []byte + if backendCapability&pnet.ClientSSL != 0 && backendTLSConfig != nil { + pkt = pnet.MakeHandshakeResponse(resp) + resp.Capability |= mysql.ClientSSL + // write SSL Packet + if err := backendIO.WritePacket(pkt[:32], true); err != nil { return err } // Send TLS / SSL request packet. The server must have supported TLS. - if err := backendIO.ClientTLSHandshake(auth.backendTLSConfig); err != nil { + tcfg := backendTLSConfig.Clone() + addr := backendIO.RemoteAddr().String() + host, _, err := net.SplitHostPort(addr) + if err == nil { + tcfg.ServerName = host + } + if err := backendIO.ClientTLSHandshake(tcfg); err != nil { return err } + } else { + pkt = pnet.MakeHandshakeResponse(resp) } // write handshake resp - return backendIO.WritePacket(data, true) + return backendIO.WritePacket(pkt, true) } func (auth *Authenticator) handleSecondAuthResult(backendIO *pnet.PacketIO) error { @@ -353,23 +338,3 @@ func (auth *Authenticator) changeUser(username, db string) { func (auth *Authenticator) updateCurrentDB(db string) { auth.dbname = db } - -func (auth *Authenticator) ClientAddr() string { - return auth.clientAddr -} - -func (auth *Authenticator) ServerAddr() string { - return auth.serverAddr -} - -func (auth *Authenticator) SetValue(key, val any) { - auth.ctxmap.Store(key, val) -} - -func (auth *Authenticator) Value(key any) any { - v, ok := auth.ctxmap.Load(key) - if !ok { - return nil - } - return v -} diff --git a/pkg/proxy/backend/authenticator_test.go b/pkg/proxy/backend/authenticator_test.go index 71a156cf..774c8493 100644 --- a/pkg/proxy/backend/authenticator_test.go +++ b/pkg/proxy/backend/authenticator_test.go @@ -15,7 +15,6 @@ package backend import ( - "net" "strings" "testing" @@ -28,50 +27,50 @@ func TestUnsupportedCapability(t *testing.T) { cfgs := [][]cfgOverrider{ { func(cfg *testConfig) { - cfg.clientConfig.capability = defaultTestClientCapability & ^mysql.ClientSSL + cfg.clientConfig.capability = defaultTestClientCapability & ^pnet.ClientSSL }, func(cfg *testConfig) { - cfg.clientConfig.capability = defaultTestClientCapability | mysql.ClientSSL + cfg.clientConfig.capability = defaultTestClientCapability | pnet.ClientSSL }, }, { func(cfg *testConfig) { - cfg.backendConfig.capability = defaultTestBackendCapability & ^mysql.ClientSSL + cfg.backendConfig.capability = defaultTestBackendCapability & ^pnet.ClientSSL }, func(cfg *testConfig) { - cfg.backendConfig.capability = defaultTestBackendCapability | mysql.ClientSSL + cfg.backendConfig.capability = defaultTestBackendCapability | pnet.ClientSSL }, }, { func(cfg *testConfig) { - cfg.backendConfig.capability = defaultTestBackendCapability & ^mysql.ClientDeprecateEOF + cfg.backendConfig.capability = defaultTestBackendCapability & ^pnet.ClientDeprecateEOF }, func(cfg *testConfig) { - cfg.backendConfig.capability = defaultTestBackendCapability | mysql.ClientDeprecateEOF + cfg.backendConfig.capability = defaultTestBackendCapability | pnet.ClientDeprecateEOF }, }, { func(cfg *testConfig) { - cfg.clientConfig.capability = defaultTestClientCapability & ^mysql.ClientProtocol41 + cfg.clientConfig.capability = defaultTestClientCapability & ^pnet.ClientProtocol41 }, func(cfg *testConfig) { - cfg.clientConfig.capability = defaultTestClientCapability | mysql.ClientProtocol41 + cfg.clientConfig.capability = defaultTestClientCapability | pnet.ClientProtocol41 }, }, { func(cfg *testConfig) { - cfg.backendConfig.capability = defaultTestClientCapability & ^mysql.ClientPSMultiResults + cfg.backendConfig.capability = defaultTestClientCapability & ^pnet.ClientPSMultiResults }, func(cfg *testConfig) { - cfg.backendConfig.capability = defaultTestClientCapability | mysql.ClientPSMultiResults + cfg.backendConfig.capability = defaultTestClientCapability | pnet.ClientPSMultiResults }, }, { func(cfg *testConfig) { - cfg.clientConfig.capability = defaultTestClientCapability & ^mysql.ClientPSMultiResults + cfg.clientConfig.capability = defaultTestClientCapability & ^pnet.ClientPSMultiResults }, func(cfg *testConfig) { - cfg.clientConfig.capability = defaultTestClientCapability | mysql.ClientPSMultiResults + cfg.clientConfig.capability = defaultTestClientCapability | pnet.ClientPSMultiResults }, }, } @@ -81,9 +80,9 @@ func TestUnsupportedCapability(t *testing.T) { for _, cfgs := range cfgOverriders { ts, clean := newTestSuite(t, tc, cfgs...) ts.authenticateFirstTime(t, func(t *testing.T, _ *testSuite) { - if ts.mb.backendConfig.capability&defRequiredBackendCaps.Uint32() != defRequiredBackendCaps.Uint32() { + if ts.mb.backendConfig.capability&defRequiredBackendCaps != defRequiredBackendCaps { require.ErrorIs(t, ts.mp.err, ErrCapabilityNegotiation) - } else if ts.mc.clientConfig.capability&requiredFrontendCaps.Uint32() != requiredFrontendCaps.Uint32() { + } else if ts.mc.clientConfig.capability&requiredFrontendCaps != requiredFrontendCaps { require.ErrorIs(t, ts.mp.err, ErrCapabilityNegotiation) } else { require.NoError(t, ts.mc.err) @@ -107,10 +106,10 @@ func TestAuthPlugin(t *testing.T) { }, { func(cfg *testConfig) { - cfg.clientConfig.capability = defaultTestClientCapability & ^mysql.ClientPluginAuth + cfg.clientConfig.capability = defaultTestClientCapability & ^pnet.ClientPluginAuth }, func(cfg *testConfig) { - cfg.clientConfig.capability = defaultTestClientCapability | mysql.ClientPluginAuth + cfg.clientConfig.capability = defaultTestClientCapability | pnet.ClientPluginAuth }, }, { @@ -152,27 +151,27 @@ func TestCapability(t *testing.T) { cfgs := [][]cfgOverrider{ { func(cfg *testConfig) { - cfg.clientConfig.capability = defaultTestClientCapability & ^mysql.ClientConnectWithDB + cfg.clientConfig.capability = defaultTestClientCapability & ^pnet.ClientConnectWithDB }, func(cfg *testConfig) { - cfg.clientConfig.capability = defaultTestClientCapability | mysql.ClientConnectWithDB + cfg.clientConfig.capability = defaultTestClientCapability | pnet.ClientConnectWithDB }, }, { func(cfg *testConfig) { - cfg.clientConfig.capability = defaultTestClientCapability & ^mysql.ClientConnectAtts + cfg.clientConfig.capability = defaultTestClientCapability & ^pnet.ClientConnectAttrs }, func(cfg *testConfig) { - cfg.clientConfig.capability = defaultTestClientCapability | mysql.ClientConnectAtts + cfg.clientConfig.capability = defaultTestClientCapability | pnet.ClientConnectAttrs cfg.clientConfig.attrs = map[string]string{"key": "value"} }, }, { func(cfg *testConfig) { - cfg.clientConfig.capability = defaultTestClientCapability & ^mysql.ClientSecureConnection + cfg.clientConfig.capability = defaultTestClientCapability & ^pnet.ClientSecureConnection }, func(cfg *testConfig) { - cfg.clientConfig.capability = defaultTestClientCapability | mysql.ClientSecureConnection + cfg.clientConfig.capability = defaultTestClientCapability | pnet.ClientSecureConnection }, }, } @@ -216,12 +215,10 @@ func TestCustomAuth(t *testing.T) { reAttrs := map[string]string{"key": "value"} reCap := SupportedServerCapabilities & ^pnet.ClientDeprecateEOF inUser := "" - inAddr := "" ts, clean := newTestSuite(t, tc, func(cfg *testConfig) { handler := cfg.proxyConfig.handler handler.handleHandshakeResp = func(ctx ConnContext, resp *pnet.HandshakeResp) error { inUser = resp.User - inAddr = ctx.ClientAddr() resp.User = reUser resp.Attrs = reAttrs return nil @@ -235,9 +232,6 @@ func TestCustomAuth(t *testing.T) { require.Equal(t, reUser, ts.mb.username) require.Equal(t, reAttrs, ts.mb.attrs) require.Equal(t, reCap&pnet.ClientDeprecateEOF, pnet.Capability(ts.mb.capability)&pnet.ClientDeprecateEOF) - host, _, err := net.SplitHostPort(inAddr) - require.NoError(t, err) - require.Equal(t, host, "::1") } ts.authenticateFirstTime(t, func(t *testing.T, ts *testSuite) {}) checker() diff --git a/pkg/proxy/backend/backend_conn.go b/pkg/proxy/backend/backend_conn.go deleted file mode 100644 index 8fdcdb10..00000000 --- a/pkg/proxy/backend/backend_conn.go +++ /dev/null @@ -1,63 +0,0 @@ -// Copyright 2022 PingCAP, Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package backend - -import ( - "net" - "time" - - "github.com/pingcap/TiProxy/lib/util/errors" - pnet "github.com/pingcap/TiProxy/pkg/proxy/net" -) - -const ( - DialTimeout = 2 * time.Second -) - -type BackendConnection struct { - pkt *pnet.PacketIO // a helper to read and write data in packet format. - address string -} - -func NewBackendConnection(address string) *BackendConnection { - return &BackendConnection{ - address: address, - } -} - -func (bc *BackendConnection) Addr() string { - return bc.address -} - -func (bc *BackendConnection) Connect() error { - cn, err := net.DialTimeout("tcp", bc.address, DialTimeout) - if err != nil { - return errors.Wrapf(err, "dial backend error") - } - pkt := pnet.NewPacketIO(cn) - bc.pkt = pkt - return nil -} - -func (bc *BackendConnection) PacketIO() *pnet.PacketIO { - return bc.pkt -} - -func (bc *BackendConnection) Close() error { - if bc.pkt != nil { - return bc.pkt.Close() - } - return nil -} diff --git a/pkg/proxy/backend/backend_conn_mgr.go b/pkg/proxy/backend/backend_conn_mgr.go index 7035c052..26dc7eb6 100644 --- a/pkg/proxy/backend/backend_conn_mgr.go +++ b/pkg/proxy/backend/backend_conn_mgr.go @@ -20,6 +20,7 @@ import ( "encoding/binary" "encoding/json" "fmt" + "net" "strings" "sync" "sync/atomic" @@ -41,6 +42,10 @@ var ( ErrCloseConnMgr = errors.New("failed to close connection manager") ) +const ( + DialTimeout = 5 * time.Second +) + const ( sqlQueryState = "SHOW SESSION_STATES" sqlSetState = "SET SESSION_STATES '%s'" @@ -101,8 +106,11 @@ type BackendConnManager struct { closeStatus atomic.Int32 // cancelFunc is used to cancel the signal processing goroutine. cancelFunc context.CancelFunc - backendConn *BackendConnection + backendIO *pnet.PacketIO + backendTLS *tls.Config handshakeHandler HandshakeHandler + ctxmap sync.Map + clientAddr string connectionID uint64 } @@ -115,10 +123,9 @@ func NewBackendConnManager(logger *zap.Logger, handshakeHandler HandshakeHandler cmdProcessor: NewCmdProcessor(), handshakeHandler: handshakeHandler, authenticator: &Authenticator{ - supportedServerCapabilities: handshakeHandler.GetCapability(), - proxyProtocol: proxyProtocol, - requireBackendTLS: requireBackendTLS, - salt: GenerateSalt(20), + proxyProtocol: proxyProtocol, + requireBackendTLS: requireBackendTLS, + salt: GenerateSalt(20), }, // There are 2 types of signals, which may be sent concurrently. signalReceived: make(chan signalType, signalTypeNums), @@ -138,8 +145,11 @@ func (mgr *BackendConnManager) Connect(ctx context.Context, clientIO *pnet.Packe mgr.processLock.Lock() defer mgr.processLock.Unlock() - err := mgr.authenticator.handshakeFirstTime(mgr.logger.Named("authenticator"), clientIO, mgr.handshakeHandler, mgr.getBackendIO, frontendTLSConfig, backendTLSConfig) - mgr.handshakeHandler.OnHandshake(mgr.authenticator, mgr.authenticator.serverAddr, err) + mgr.backendTLS = backendTLSConfig + + mgr.clientAddr = clientIO.RemoteAddr().String() + err := mgr.authenticator.handshakeFirstTime(mgr.logger.Named("authenticator"), mgr, clientIO, mgr.handshakeHandler, mgr.getBackendIO, frontendTLSConfig, backendTLSConfig) + mgr.handshakeHandler.OnHandshake(mgr, mgr.ServerAddr(), err) if err != nil { return err } @@ -153,8 +163,8 @@ func (mgr *BackendConnManager) Connect(ctx context.Context, clientIO *pnet.Packe return nil } -func (mgr *BackendConnManager) getBackendIO(ctx ConnContext, auth *Authenticator, resp *pnet.HandshakeResp, timeout time.Duration) (*pnet.PacketIO, error) { - r, err := mgr.handshakeHandler.GetRouter(auth, resp) +func (mgr *BackendConnManager) getBackendIO(cctx ConnContext, auth *Authenticator, resp *pnet.HandshakeResp, timeout time.Duration) (*pnet.PacketIO, error) { + r, err := mgr.handshakeHandler.GetRouter(cctx, resp) if err != nil { return nil, err } @@ -163,35 +173,44 @@ func (mgr *BackendConnManager) getBackendIO(ctx ConnContext, auth *Authenticator // - One TiDB may be just shut down and another is just started but not ready yet bctx, cancel := context.WithTimeout(context.Background(), timeout) selector := r.GetBackendSelector() + var addr string io, err := backoff.RetryNotifyWithData( func() (*pnet.PacketIO, error) { // Try to connect to all backup backends one by one. - selector.Reset() - for { - addr := selector.Next() - if len(addr) == 0 { + addr := selector.Next() + + // if all addrs are enumerated, reset and try again + if addr == "" { + selector.Reset() + addr = selector.Next() + if addr == "" { return nil, router.ErrNoInstanceToSelect } - backendConn := NewBackendConnection(addr) - err := backendConn.Connect() - mgr.handshakeHandler.OnHandshake(auth, addr, err) - if err == nil { - if err = selector.Succeed(mgr); err == nil { - mgr.logger.Info("connected to backend", zap.String("addr", addr)) - mgr.backendConn = backendConn - auth.serverAddr = addr - return mgr.backendConn.PacketIO(), nil - } - // Bad luck: the backend has been recycled or shut down just after the selector returns it. - if ignoredErr := backendConn.Close(); ignoredErr != nil { - mgr.logger.Error("close backend connection failed", zap.String("addr", addr), zap.Error(ignoredErr)) - } + } + + cn, err := net.DialTimeout("tcp", addr, DialTimeout) + if err != nil { + return nil, errors.Wrapf(err, "dial backend %s error", addr) + } + + if err := selector.Succeed(mgr); err != nil { + // Bad luck: the backend has been recycled or shut down just after the selector returns it. + if ignoredErr := cn.Close(); ignoredErr != nil { + mgr.logger.Error("close backend connection failed", zap.String("addr", addr), zap.Error(ignoredErr)) } + return nil, err } + + mgr.logger.Info("connected to backend", zap.String("addr", addr)) + // NOTE: should use DNS name as much as possible + // Usually certs are signed with domain instead of IP addrs + // And `RemoteAddr()` will return IP addr + mgr.backendIO = pnet.NewPacketIO(cn, pnet.WithRemoteAddr(addr)) + return mgr.backendIO, nil }, backoff.WithContext(backoff.NewConstantBackOff(200*time.Millisecond), bctx), func(err error, d time.Duration) { - mgr.handshakeHandler.OnHandshake(auth, "", err) + mgr.handshakeHandler.OnHandshake(cctx, addr, err) }, ) cancel() @@ -214,9 +233,9 @@ func (mgr *BackendConnManager) ExecuteCmd(ctx context.Context, request []byte, c return nil } waitingRedirect := atomic.LoadPointer(&mgr.signal) != nil - holdRequest, err := mgr.cmdProcessor.executeCmd(request, clientIO, mgr.backendConn.PacketIO(), waitingRedirect) + holdRequest, err := mgr.cmdProcessor.executeCmd(request, clientIO, mgr.backendIO, waitingRedirect) if !holdRequest { - addCmdMetrics(cmd, mgr.backendConn.Addr(), startTime) + addCmdMetrics(cmd, mgr.ServerAddr(), startTime) } if err != nil { if !IsMySQLError(err) { @@ -252,8 +271,8 @@ func (mgr *BackendConnManager) ExecuteCmd(ctx context.Context, request []byte, c if waitingRedirect && holdRequest { mgr.tryRedirect(ctx, clientIO) // Execute the held request no matter redirection succeeds or not. - _, err = mgr.cmdProcessor.executeCmd(request, clientIO, mgr.backendConn.PacketIO(), false) - addCmdMetrics(cmd, mgr.backendConn.Addr(), startTime) + _, err = mgr.cmdProcessor.executeCmd(request, clientIO, mgr.backendIO, false) + addCmdMetrics(cmd, mgr.ServerAddr(), startTime) if err != nil && !IsMySQLError(err) { return err } @@ -293,7 +312,7 @@ func (mgr *BackendConnManager) initSessionStates(backendIO *pnet.PacketIO, sessi func (mgr *BackendConnManager) querySessionStates() (sessionStates, sessionToken string, err error) { // Do not lock here because the caller already locks. var result *gomysql.Result - if result, _, err = mgr.cmdProcessor.query(mgr.backendConn.PacketIO(), sqlQueryState); err != nil { + if result, _, err = mgr.cmdProcessor.query(mgr.backendIO, sqlQueryState); err != nil { return } if sessionStates, err = result.GetStringByName(0, sessionStatesCol); err != nil { @@ -343,7 +362,7 @@ func (mgr *BackendConnManager) tryRedirect(ctx context.Context, clientIO *pnet.P } rs := &redirectResult{ - from: mgr.backendConn.Addr(), + from: mgr.ServerAddr(), to: signal.newAddr, } defer func() { @@ -362,29 +381,32 @@ func (mgr *BackendConnManager) tryRedirect(ctx context.Context, clientIO *pnet.P return } - newConn := NewBackendConnection(rs.to) - if rs.err = newConn.Connect(); rs.err != nil { - mgr.handshakeHandler.OnHandshake(mgr.authenticator, rs.to, rs.err) + var cn net.Conn + cn, rs.err = net.DialTimeout("tcp", rs.to, DialTimeout) + if rs.err != nil { + mgr.handshakeHandler.OnHandshake(mgr, rs.to, rs.err) return } - mgr.authenticator.serverAddr = rs.to - mgr.authenticator.clientAddr = clientIO.SourceAddr().String() - if rs.err = mgr.authenticator.handshakeSecondTime(mgr.logger, clientIO, newConn.PacketIO(), sessionToken); rs.err == nil { - rs.err = mgr.initSessionStates(newConn.PacketIO(), sessionStates) + newBackendIO := pnet.NewPacketIO(cn, pnet.WithRemoteAddr(rs.to)) + + mgr.clientAddr = clientIO.RemoteAddr().String() + if rs.err = mgr.authenticator.handshakeSecondTime(mgr.logger, clientIO, newBackendIO, mgr.backendTLS, sessionToken); rs.err == nil { + rs.err = mgr.initSessionStates(newBackendIO, sessionStates) } else { - mgr.handshakeHandler.OnHandshake(mgr.authenticator, mgr.authenticator.serverAddr, rs.err) + mgr.handshakeHandler.OnHandshake(mgr, newBackendIO.RemoteAddr().String(), rs.err) } if rs.err != nil { - if ignoredErr := newConn.Close(); ignoredErr != nil && !pnet.IsDisconnectError(ignoredErr) { + if ignoredErr := newBackendIO.Close(); ignoredErr != nil && !pnet.IsDisconnectError(ignoredErr) { mgr.logger.Error("close new backend connection failed", zap.Error(ignoredErr)) } return } - if ignoredErr := mgr.backendConn.Close(); ignoredErr != nil && !pnet.IsDisconnectError(ignoredErr) { + if ignoredErr := mgr.backendIO.Close(); ignoredErr != nil && !pnet.IsDisconnectError(ignoredErr) { mgr.logger.Error("close previous backend connection failed", zap.Error(ignoredErr)) } - mgr.backendConn = newConn - mgr.handshakeHandler.OnHandshake(mgr.authenticator, mgr.authenticator.serverAddr, nil) + + mgr.backendIO = newBackendIO + mgr.handshakeHandler.OnHandshake(mgr, mgr.ServerAddr(), nil) } // The original db in the auth info may be dropped during the session, so we need to authenticate with the current db. @@ -462,11 +484,31 @@ func (mgr *BackendConnManager) tryGracefulClose(ctx context.Context, clientIO *p } // Closing clientIO will cause the whole connection to be closed. if err := clientIO.GracefulClose(); err != nil { - mgr.logger.Warn("graceful close client IO error", zap.Stringer("addr", clientIO.SourceAddr()), zap.Error(err)) + mgr.logger.Warn("graceful close client IO error", zap.Stringer("addr", clientIO.RemoteAddr()), zap.Error(err)) } mgr.closeStatus.Store(statusClosing) } +func (mgr *BackendConnManager) ClientAddr() string { + return mgr.clientAddr +} + +func (mgr *BackendConnManager) ServerAddr() string { + return mgr.backendIO.RemoteAddr().String() +} + +func (mgr *BackendConnManager) SetValue(key, val any) { + mgr.ctxmap.Store(key, val) +} + +func (mgr *BackendConnManager) Value(key any) any { + v, ok := mgr.ctxmap.Load(key) + if !ok { + return nil + } + return v +} + // Close releases all resources. func (mgr *BackendConnManager) Close() error { mgr.closeStatus.Store(statusClosing) @@ -479,14 +521,14 @@ func (mgr *BackendConnManager) Close() error { var connErr error var addr string mgr.processLock.Lock() - if mgr.backendConn != nil { - addr = mgr.backendConn.address - connErr = mgr.backendConn.Close() - mgr.backendConn = nil + if mgr.backendIO != nil { + addr = mgr.ServerAddr() + connErr = mgr.backendIO.Close() + mgr.backendIO = nil } mgr.processLock.Unlock() - handErr := mgr.handshakeHandler.OnConnClose(mgr.authenticator) + handErr := mgr.handshakeHandler.OnConnClose(mgr) eventReceiver := mgr.getEventReceiver() if eventReceiver != nil { diff --git a/pkg/proxy/backend/backend_conn_mgr_test.go b/pkg/proxy/backend/backend_conn_mgr_test.go index c8702008..fc74469f 100644 --- a/pkg/proxy/backend/backend_conn_mgr_test.go +++ b/pkg/proxy/backend/backend_conn_mgr_test.go @@ -16,6 +16,7 @@ package backend import ( "context" + "fmt" "net" "sync/atomic" "testing" @@ -180,32 +181,32 @@ func (ts *backendMgrTester) startTxn4Backend(packetIO *pnet.PacketIO) error { func (ts *backendMgrTester) checkNotRedirected4Proxy(clientIO, backendIO *pnet.PacketIO) error { signal := (*signalRedirect)(atomic.LoadPointer(&ts.mp.signal)) require.Nil(ts.t, signal) - backend1 := ts.mp.backendConn + backend1 := ts.mp.backendIO // There is no other way to verify it's not redirected. // The buffer size of channel signalReceived is 0, so after the second redirect signal is sent, // we can ensure that the first signal is already processed. ts.mp.Redirect(ts.tc.backendListener.Addr().String()) ts.mp.signalReceived <- signalTypeRedirect // The backend connection is still the same. - require.Equal(ts.t, backend1, ts.mp.backendConn) + require.Equal(ts.t, backend1, ts.mp.backendIO) return nil } func (ts *backendMgrTester) redirectAfterCmd4Proxy(clientIO, backendIO *pnet.PacketIO) error { - backend1 := ts.mp.backendConn + backend1 := ts.mp.backendIO err := ts.forwardCmd4Proxy(clientIO, backendIO) require.NoError(ts.t, err) ts.mp.getEventReceiver().(*mockEventReceiver).checkEvent(ts.t, eventSucceed) - require.NotEqual(ts.t, backend1, ts.mp.backendConn) + require.NotEqual(ts.t, backend1, ts.mp.backendIO) require.Len(ts.t, ts.mp.GetRedirectingAddr(), 0) return nil } func (ts *backendMgrTester) redirectFail4Proxy(clientIO, backendIO *pnet.PacketIO) error { - backend1 := ts.mp.backendConn + backend1 := ts.mp.backendIO ts.mp.Redirect(ts.tc.backendListener.Addr().String()) ts.mp.getEventReceiver().(*mockEventReceiver).checkEvent(ts.t, eventFail) - require.Equal(ts.t, backend1, ts.mp.backendConn) + require.Equal(ts.t, backend1, ts.mp.backendIO) require.Len(ts.t, ts.mp.GetRedirectingAddr(), 0) return nil } @@ -241,10 +242,10 @@ func TestNormalRedirect(t *testing.T) { { client: nil, proxy: func(_, _ *pnet.PacketIO) error { - backend1 := ts.mp.backendConn + backend1 := ts.mp.backendIO ts.mp.Redirect(ts.tc.backendListener.Addr().String()) ts.mp.getEventReceiver().(*mockEventReceiver).checkEvent(t, eventSucceed) - require.NotEqual(t, backend1, ts.mp.backendConn) + require.NotEqual(t, backend1, ts.mp.backendIO) return nil }, backend: ts.redirectSucceed4Backend, @@ -344,11 +345,11 @@ func TestRedirectInTxn(t *testing.T) { return ts.mc.request(packetIO) }, proxy: func(clientIO, backendIO *pnet.PacketIO) error { - backend1 := ts.mp.backendConn + backend1 := ts.mp.backendIO err := ts.forwardCmd4Proxy(clientIO, backendIO) require.NoError(t, err) ts.mp.getEventReceiver().(*mockEventReceiver).checkEvent(t, eventFail) - require.Equal(t, backend1, ts.mp.backendConn) + require.Equal(t, backend1, ts.mp.backendIO) return nil }, backend: func(packetIO *pnet.PacketIO) error { @@ -492,10 +493,10 @@ func TestSpecialCmds(t *testing.T) { { client: nil, proxy: func(_, _ *pnet.PacketIO) error { - backend1 := ts.mp.backendConn + backend1 := ts.mp.backendIO ts.mp.Redirect(ts.tc.backendListener.Addr().String()) ts.mp.getEventReceiver().(*mockEventReceiver).checkEvent(t, eventSucceed) - require.NotEqual(t, backend1, ts.mp.backendConn) + require.NotEqual(t, backend1, ts.mp.backendIO) return nil }, backend: func(packetIO *pnet.PacketIO) error { @@ -503,8 +504,8 @@ func TestSpecialCmds(t *testing.T) { require.NoError(t, ts.redirectSucceed4Backend(packetIO)) require.Equal(t, "another_user", ts.mb.username) require.Equal(t, "session_db", ts.mb.db) - expectCap := pnet.Capability(ts.mp.authenticator.supportedServerCapabilities.Uint32() &^ (mysql.ClientMultiStatements | mysql.ClientPluginAuthLenencClientData)) - gotCap := pnet.Capability(ts.mb.capability &^ mysql.ClientPluginAuthLenencClientData) + expectCap := pnet.Capability(ts.mp.handshakeHandler.GetCapability() &^ (pnet.ClientMultiStatements | pnet.ClientPluginAuthLenencClientData)) + gotCap := pnet.Capability(ts.mb.capability &^ pnet.ClientPluginAuthLenencClientData) require.Equal(t, expectCap, gotCap, "expected=%s,got=%s", expectCap, gotCap) return nil }, @@ -588,10 +589,10 @@ func TestCustomHandshake(t *testing.T) { { client: nil, proxy: func(_, _ *pnet.PacketIO) error { - backend1 := ts.mp.backendConn + backend1 := ts.mp.backendIO ts.mp.Redirect(ts.tc.backendListener.Addr().String()) ts.mp.getEventReceiver().(*mockEventReceiver).checkEvent(t, eventSucceed) - require.NotEqual(t, backend1, ts.mp.backendConn) + require.NotEqual(t, backend1, ts.mp.backendIO) return nil }, backend: ts.redirectSucceed4Backend, @@ -687,14 +688,16 @@ func TestGracefulCloseBeforeHandshake(t *testing.T) { } func TestGetBackendIO(t *testing.T) { - listeners := make([]net.Listener, 0, 3) addrs := make([]string, 0, 3) - for i := 0; i < 3; i++ { + listeners := make([]net.Listener, 0, cap(addrs)) + + for i := 0; i < cap(addrs); i++ { listener, err := net.Listen("tcp", "0.0.0.0:0") require.NoError(t, err) listeners = append(listeners, listener) addrs = append(addrs, listener.Addr().String()) } + rt := router.NewStaticRouter(addrs) badAddrs := make(map[string]struct{}, 3) handler := &CustomHandshakeHandler{ @@ -708,19 +711,29 @@ func TestGetBackendIO(t *testing.T) { }, } mgr := NewBackendConnManager(logger.CreateLoggerForTest(t), handler, 0, false, false) + var wg waitgroup.WaitGroup for i := 0; i <= len(listeners); i++ { - io, err := mgr.getBackendIO(mgr.authenticator, mgr.authenticator, nil, time.Second) + wg.Run(func() { + if i < len(listeners) { + cn, err := listeners[i].Accept() + require.NoError(t, err) + require.NoError(t, cn.Close()) + } + }) + io, err := mgr.getBackendIO(mgr, mgr.authenticator, nil, time.Second) if err == nil { require.NoError(t, io.Close()) } + message := fmt.Sprintf("%d: %s, %+v\n", i, badAddrs, err) if i < len(listeners) { - require.NoError(t, err) + require.NoError(t, err, message) err = listeners[i].Close() - require.NoError(t, err) + require.NoError(t, err, message) } else { - require.ErrorIs(t, err, context.DeadlineExceeded) + require.ErrorIs(t, err, context.DeadlineExceeded, message) } - require.True(t, len(badAddrs) <= i) + require.True(t, len(badAddrs) <= i, message) badAddrs = make(map[string]struct{}, 3) + wg.Wait() } } diff --git a/pkg/proxy/backend/cmd_processor_test.go b/pkg/proxy/backend/cmd_processor_test.go index 36f1fefd..c5a8bd02 100644 --- a/pkg/proxy/backend/cmd_processor_test.go +++ b/pkg/proxy/backend/cmd_processor_test.go @@ -86,7 +86,7 @@ func TestForwardCommands(t *testing.T) { // Test every respond type for every command. for cmd, respondTypes := range cmdResponseTypes { for _, respondType := range respondTypes { - for _, capability := range []uint32{defaultTestBackendCapability &^ mysql.ClientDeprecateEOF, defaultTestBackendCapability | mysql.ClientDeprecateEOF} { + for _, capability := range []pnet.Capability{defaultTestBackendCapability &^ pnet.ClientDeprecateEOF, defaultTestBackendCapability | pnet.ClientDeprecateEOF} { cfgOvr := func(cfg *testConfig) { cfg.clientConfig.cmd = cmd cfg.backendConfig.respondType = respondType @@ -161,7 +161,7 @@ func TestDirectQuery(t *testing.T) { }, { cfg: func(cfg *testConfig) { - cfg.clientConfig.capability = defaultTestBackendCapability &^ mysql.ClientDeprecateEOF + cfg.clientConfig.capability = defaultTestBackendCapability &^ pnet.ClientDeprecateEOF cfg.proxyConfig.capability = cfg.clientConfig.capability cfg.backendConfig.capability = cfg.clientConfig.capability cfg.backendConfig.columns = 2 diff --git a/pkg/proxy/backend/mock_backend_test.go b/pkg/proxy/backend/mock_backend_test.go index f29177e7..7abcd4af 100644 --- a/pkg/proxy/backend/mock_backend_test.go +++ b/pkg/proxy/backend/mock_backend_test.go @@ -34,7 +34,7 @@ type backendConfig struct { rows int respondType respondType stmtNum int - capability uint32 + capability pnet.Capability status uint16 authSucceed bool abnormalExit bool @@ -85,7 +85,7 @@ func (mb *mockBackend) authenticate(packetIO *pnet.PacketIO) error { } // upgrade to TLS capability := binary.LittleEndian.Uint16(clientPkt[:2]) - sslEnabled := uint32(capability)&mysql.ClientSSL > 0 && mb.capability&mysql.ClientSSL > 0 + sslEnabled := uint32(capability)&mysql.ClientSSL > 0 && mb.capability&pnet.ClientSSL > 0 if sslEnabled { if _, err = packetIO.ServerTLSHandshake(mb.tlsConfig); err != nil { return err @@ -100,7 +100,7 @@ func (mb *mockBackend) authenticate(packetIO *pnet.PacketIO) error { mb.db = resp.DB mb.authData = resp.AuthData mb.attrs = resp.Attrs - mb.capability = resp.Capability + mb.capability = pnet.Capability(resp.Capability) // verify password return mb.verifyPassword(packetIO, resp) } @@ -216,7 +216,7 @@ func (mb *mockBackend) respondColumns(packetIO *pnet.PacketIO) error { } func (mb *mockBackend) writeResultEndPacket(packetIO *pnet.PacketIO, status uint16) error { - if mb.capability&mysql.ClientDeprecateEOF > 0 { + if mb.capability&pnet.ClientDeprecateEOF > 0 { return packetIO.WriteOKPacket(status, mysql.EOFHeader) } return packetIO.WriteEOFPacket(status) @@ -272,7 +272,7 @@ func (mb *mockBackend) writeResultSet(packetIO *pnet.PacketIO, names []string, v } if status&mysql.ServerStatusCursorExists == 0 { - if mb.capability&mysql.ClientDeprecateEOF == 0 { + if mb.capability&pnet.ClientDeprecateEOF == 0 { if err := packetIO.WriteEOFPacket(status); err != nil { return err } @@ -344,7 +344,7 @@ func (mb *mockBackend) respondPrepare(packetIO *pnet.PacketIO) error { return err } } - if mb.capability&mysql.ClientDeprecateEOF == 0 { + if mb.capability&pnet.ClientDeprecateEOF == 0 { if err := packetIO.WriteEOFPacket(mb.status); err != nil { return err } @@ -356,7 +356,7 @@ func (mb *mockBackend) respondPrepare(packetIO *pnet.PacketIO) error { return err } } - if mb.capability&mysql.ClientDeprecateEOF == 0 { + if mb.capability&pnet.ClientDeprecateEOF == 0 { if err := packetIO.WriteEOFPacket(mb.status); err != nil { return err } diff --git a/pkg/proxy/backend/mock_client_test.go b/pkg/proxy/backend/mock_client_test.go index bb28cb64..322f4675 100644 --- a/pkg/proxy/backend/mock_client_test.go +++ b/pkg/proxy/backend/mock_client_test.go @@ -34,7 +34,7 @@ type clientConfig struct { authData []byte filePkts int prepStmtID int - capability uint32 + capability pnet.Capability collation uint8 // for cmd cmd byte @@ -87,11 +87,11 @@ func (mc *mockClient) authenticate(packetIO *pnet.PacketIO) error { AuthPlugin: mc.authPlugin, Attrs: mc.attrs, AuthData: mc.authData, - Capability: mc.capability, + Capability: mc.capability.Uint32(), Collation: mc.collation, } pkt = pnet.MakeHandshakeResponse(resp) - if mc.capability&mysql.ClientSSL > 0 { + if mc.capability&pnet.ClientSSL > 0 { if err := packetIO.WritePacket(pkt[:32], true); err != nil { return err } @@ -208,7 +208,7 @@ func (mc *mockClient) requestPrepare(packetIO *pnet.PacketIO) error { numColumns := binary.LittleEndian.Uint16(response[5:]) numParams := binary.LittleEndian.Uint16(response[7:]) expectedPacketNum = int(numColumns) + int(numParams) - if mc.capability&mysql.ClientDeprecateEOF == 0 { + if mc.capability&pnet.ClientDeprecateEOF == 0 { if numColumns > 0 { expectedPacketNum++ } @@ -270,7 +270,7 @@ func (mc *mockClient) readUntilResultEnd(packetIO *pnet.PacketIO) (pkt []byte, e if pkt[0] == mysql.ErrHeader { return } - if mc.capability&mysql.ClientDeprecateEOF == 0 { + if mc.capability&pnet.ClientDeprecateEOF == 0 { if pnet.IsEOFPacket(pkt) { break } @@ -331,7 +331,7 @@ func (mc *mockClient) readResultSet(packetIO *pnet.PacketIO) error { } default: // read result set - if mc.capability&mysql.ClientDeprecateEOF == 0 { + if mc.capability&pnet.ClientDeprecateEOF == 0 { if pkt, err = mc.readUntilResultEnd(packetIO); err != nil { return err } @@ -349,7 +349,7 @@ func (mc *mockClient) readResultSet(packetIO *pnet.PacketIO) error { if pkt[0] == mysql.ErrHeader { return nil } - if mc.capability&mysql.ClientDeprecateEOF == 0 { + if mc.capability&pnet.ClientDeprecateEOF == 0 { serverStatus = binary.LittleEndian.Uint16(pkt[3:]) } else { rs := pnet.ParseOKPacket(pkt) diff --git a/pkg/proxy/backend/mock_proxy_test.go b/pkg/proxy/backend/mock_proxy_test.go index 52f6dee0..2bfd6901 100644 --- a/pkg/proxy/backend/mock_proxy_test.go +++ b/pkg/proxy/backend/mock_proxy_test.go @@ -30,7 +30,7 @@ type proxyConfig struct { backendTLSConfig *tls.Config handler *CustomHandshakeHandler sessionToken string - capability uint32 + capability pnet.Capability waitRedirect bool } @@ -43,8 +43,9 @@ func newProxyConfig() *proxyConfig { } type mockProxy struct { - *proxyConfig *BackendConnManager + + *proxyConfig // outputs that received from the server. rs *gomysql.Result // execution results @@ -59,12 +60,12 @@ func newMockProxy(t *testing.T, cfg *proxyConfig) *mockProxy { logger: logger.CreateLoggerForTest(t).Named("mockProxy"), BackendConnManager: NewBackendConnManager(logger.CreateLoggerForTest(t), cfg.handler, 0, false, false), } - mp.cmdProcessor.capability = cfg.capability + mp.cmdProcessor.capability = cfg.capability.Uint32() return mp } func (mp *mockProxy) authenticateFirstTime(clientIO, backendIO *pnet.PacketIO) error { - if err := mp.authenticator.handshakeFirstTime(mp.logger, clientIO, mp.handshakeHandler, func(ConnContext, *Authenticator, *pnet.HandshakeResp, time.Duration) (*pnet.PacketIO, error) { + if err := mp.authenticator.handshakeFirstTime(mp.logger, mp, clientIO, mp.handshakeHandler, func(ctx ConnContext, auth *Authenticator, resp *pnet.HandshakeResp, timeout time.Duration) (*pnet.PacketIO, error) { return backendIO, nil }, mp.frontendTLSConfig, mp.backendTLSConfig); err != nil { return err @@ -74,7 +75,7 @@ func (mp *mockProxy) authenticateFirstTime(clientIO, backendIO *pnet.PacketIO) e } func (mp *mockProxy) authenticateSecondTime(clientIO, backendIO *pnet.PacketIO) error { - return mp.authenticator.handshakeSecondTime(mp.logger, clientIO, backendIO, mp.sessionToken) + return mp.authenticator.handshakeSecondTime(mp.logger, clientIO, backendIO, mp.backendTLSConfig, mp.sessionToken) } func (mp *mockProxy) processCmd(clientIO, backendIO *pnet.PacketIO) error { diff --git a/pkg/proxy/backend/testsuite_test.go b/pkg/proxy/backend/testsuite_test.go index 71555baa..01621fe2 100644 --- a/pkg/proxy/backend/testsuite_test.go +++ b/pkg/proxy/backend/testsuite_test.go @@ -21,7 +21,6 @@ import ( "github.com/pingcap/TiProxy/pkg/manager/router" pnet "github.com/pingcap/TiProxy/pkg/proxy/net" - "github.com/pingcap/tidb/parser/mysql" "github.com/stretchr/testify/require" ) @@ -34,12 +33,12 @@ import ( // sent from the server and vice versa. const ( - defaultTestBackendCapability = mysql.ClientLongPassword | mysql.ClientFoundRows | mysql.ClientLongFlag | - mysql.ClientConnectWithDB | mysql.ClientNoSchema | mysql.ClientODBC | mysql.ClientLocalFiles | mysql.ClientIgnoreSpace | - mysql.ClientProtocol41 | mysql.ClientInteractive | mysql.ClientSSL | mysql.ClientIgnoreSigpipe | - mysql.ClientTransactions | mysql.ClientReserved | mysql.ClientSecureConnection | mysql.ClientMultiStatements | - mysql.ClientMultiResults | mysql.ClientPluginAuth | mysql.ClientConnectAtts | mysql.ClientPluginAuthLenencClientData | - mysql.ClientDeprecateEOF + defaultTestBackendCapability = pnet.ClientLongPassword | pnet.ClientFoundRows | pnet.ClientLongFlag | + pnet.ClientConnectWithDB | pnet.ClientNoSchema | pnet.ClientODBC | pnet.ClientLocalFiles | pnet.ClientIgnoreSpace | + pnet.ClientProtocol41 | pnet.ClientInteractive | pnet.ClientSSL | pnet.ClientIgnoreSigpipe | + pnet.ClientTransactions | pnet.ClientReserved | pnet.ClientSecureConnection | pnet.ClientMultiStatements | + pnet.ClientMultiResults | pnet.ClientPluginAuth | pnet.ClientConnectAttrs | pnet.ClientPluginAuthLenencClientData | + pnet.ClientDeprecateEOF defaultTestClientCapability = defaultTestBackendCapability ) diff --git a/pkg/proxy/net/mysql.go b/pkg/proxy/net/mysql.go index ac58cd66..eaebf173 100644 --- a/pkg/proxy/net/mysql.go +++ b/pkg/proxy/net/mysql.go @@ -36,7 +36,7 @@ var ( ) // ParseInitialHandshake parses the initial handshake received from the server. -func ParseInitialHandshake(data []byte) uint32 { +func ParseInitialHandshake(data []byte) Capability { // skip mysql version pos := 1 + bytes.IndexByte(data[1:], 0x00) + 1 // skip connection id @@ -59,7 +59,7 @@ func ParseInitialHandshake(data []byte) uint32 { // skip salt second part // skip auth plugin } - return capability + return Capability(capability) } // HandshakeResp indicates the response read from the client. diff --git a/pkg/proxy/net/packetio.go b/pkg/proxy/net/packetio.go index 0251a6cb..bedf3fcf 100644 --- a/pkg/proxy/net/packetio.go +++ b/pkg/proxy/net/packetio.go @@ -78,6 +78,7 @@ type PacketIO struct { buf *bufio.ReadWriter proxyInited *atomic.Bool proxy *Proxy + remoteAddr net.Addr wrap error sequence uint8 } @@ -120,13 +121,8 @@ func (p *PacketIO) LocalAddr() net.Addr { } func (p *PacketIO) RemoteAddr() net.Addr { - return p.conn.RemoteAddr() -} - -// SourceAddr returns the source address if proxy protocol is enabled. -func (p *PacketIO) SourceAddr() net.Addr { - if proxy := p.Proxy(); proxy != nil { - return proxy.SrcAddress + if p.remoteAddr != nil { + return p.remoteAddr } return p.conn.RemoteAddr() } diff --git a/pkg/proxy/net/packetio_mysql.go b/pkg/proxy/net/packetio_mysql.go index 5a63a68d..dbd57a9f 100644 --- a/pkg/proxy/net/packetio_mysql.go +++ b/pkg/proxy/net/packetio_mysql.go @@ -27,7 +27,7 @@ var ( // WriteInitialHandshake writes an initial handshake as a server. // It's used for tenant-aware routing and testing. -func (p *PacketIO) WriteInitialHandshake(capability uint32, salt []byte, authPlugin string) error { +func (p *PacketIO) WriteInitialHandshake(capability Capability, salt []byte, authPlugin string) error { saltLen := len(salt) if saltLen < 8 { return ErrSaltNotLongEnough diff --git a/pkg/proxy/net/packetio_options.go b/pkg/proxy/net/packetio_options.go index 73d8968b..831234e1 100644 --- a/pkg/proxy/net/packetio_options.go +++ b/pkg/proxy/net/packetio_options.go @@ -14,7 +14,11 @@ package net -import "go.uber.org/atomic" +import ( + "net" + + "go.uber.org/atomic" +) type PacketIOption = func(*PacketIO) @@ -27,3 +31,24 @@ func WithWrapError(err error) func(pi *PacketIO) { pi.wrap = err } } + +// WithRemoteAddr +var _ net.Addr = &oriRemoteAddr{} + +type oriRemoteAddr struct { + addr string +} + +func (o *oriRemoteAddr) Network() string { + return "tcp" +} + +func (o *oriRemoteAddr) String() string { + return o.addr +} + +func WithRemoteAddr(readdr string) func(pi *PacketIO) { + return func(pi *PacketIO) { + pi.remoteAddr = &oriRemoteAddr{addr: readdr} + } +} diff --git a/pkg/proxy/net/proxy.go b/pkg/proxy/net/proxy.go index 2cde7d70..cc228b7d 100644 --- a/pkg/proxy/net/proxy.go +++ b/pkg/proxy/net/proxy.go @@ -271,6 +271,8 @@ func (p *PacketIO) parseProxyV2() (*Proxy, error) { buf = buf[3+length:] } + // set RemoteAddr in case of proxy. + p.remoteAddr = m.SrcAddress return m, nil }