diff --git a/pkg/manager/namespace/manager.go b/pkg/manager/namespace/manager.go index 025fcbc1..e979de62 100644 --- a/pkg/manager/namespace/manager.go +++ b/pkg/manager/namespace/manager.go @@ -41,7 +41,7 @@ func NewNamespaceManager() *NamespaceManager { } func (mgr *NamespaceManager) buildNamespace(cfg *config.Namespace, client *clientv3.Client) (*Namespace, error) { logger := mgr.logger.With(zap.String("namespace", cfg.Namespace)) - rt, err := router.NewRandomRouter(&cfg.Backend, client) + rt, err := router.NewScoreBasedRouter(&cfg.Backend, client) if err != nil { return nil, errors.Errorf("build router error: %w", err) } diff --git a/pkg/manager/router/router.go b/pkg/manager/router/router.go index 7055374c..ab3ec2b0 100644 --- a/pkg/manager/router/router.go +++ b/pkg/manager/router/router.go @@ -27,6 +27,7 @@ import ( "go.uber.org/zap" ) +// Router routes client connections to backends. type Router interface { Route(RedirectableConn) (string, error) RedirectConnections() error @@ -50,45 +51,55 @@ const ( rebalanceMaxScoreRatio = 1.1 ) +// ConnEventReceiver receives connection events. type ConnEventReceiver interface { OnRedirectSucceed(from, to string, conn RedirectableConn) OnRedirectFail(from, to string, conn RedirectableConn) OnConnClosed(addr string, conn RedirectableConn) } +// RedirectableConn indicates a redirect-able connection. type RedirectableConn interface { SetEventReceiver(receiver ConnEventReceiver) Redirect(addr string) + GetRedirectingAddr() string ConnectionID() uint64 } -type BackendWrapper struct { +// backendWrapper contains the connections on the backend. +type backendWrapper struct { status BackendStatus addr string - // A list of *ConnWrapper and is ordered by the connecting or redirecting time. + // A list of *connWrapper and is ordered by the connecting or redirecting time. + // connList and connMap include moving out connections but not moving in connections. connList *list.List connMap map[uint64]*list.Element } -func (b *BackendWrapper) score() int { +// score calculates the score of the backend. Larger score indicates higher load. +func (b *backendWrapper) score() int { return b.status.ToScore() + b.connList.Len() } -type ConnWrapper struct { +// connWrapper wraps RedirectableConn. +type connWrapper struct { RedirectableConn phase int } -type RandomRouter struct { +// ScoreBasedRouter is an implementation of Router interface. +// It routes a connection based on score. +type ScoreBasedRouter struct { sync.Mutex observer *BackendObserver cancelFunc context.CancelFunc - // A list of *BackendWrapper and ordered by the score of the backends. + // A list of *backendWrapper. The backends are in descending order of scores. backends *list.List } -func NewRandomRouter(cfg *config.BackendNamespace, client *clientv3.Client) (*RandomRouter, error) { - router := &RandomRouter{ +// NewScoreBasedRouter creates a ScoreBasedRouter. +func NewScoreBasedRouter(cfg *config.BackendNamespace, client *clientv3.Client) (*ScoreBasedRouter, error) { + router := &ScoreBasedRouter{ backends: list.New(), } router.Lock() @@ -104,19 +115,20 @@ func NewRandomRouter(cfg *config.BackendNamespace, client *clientv3.Client) (*Ra return router, err } -func (router *RandomRouter) Route(conn RedirectableConn) (string, error) { +// Route implements Router.Route interface. +func (router *ScoreBasedRouter) Route(conn RedirectableConn) (string, error) { router.Lock() defer router.Unlock() be := router.backends.Back() if be == nil { return "", ErrNoInstanceToSelect } - backend := be.Value.(*BackendWrapper) + backend := be.Value.(*backendWrapper) switch backend.status { case StatusCannotConnect, StatusSchemaOutdated: return "", ErrNoInstanceToSelect } - connWrapper := &ConnWrapper{ + connWrapper := &connWrapper{ RedirectableConn: conn, phase: phaseNotRedirected, } @@ -125,27 +137,28 @@ func (router *RandomRouter) Route(conn RedirectableConn) (string, error) { return backend.addr, nil } -func (router *RandomRouter) removeConn(be *list.Element, ce *list.Element) { - backend := be.Value.(*BackendWrapper) - conn := ce.Value.(*ConnWrapper) +func (router *ScoreBasedRouter) removeConn(be *list.Element, ce *list.Element) { + backend := be.Value.(*backendWrapper) + conn := ce.Value.(*connWrapper) backend.connList.Remove(ce) delete(backend.connMap, conn.ConnectionID()) router.adjustBackendList(be) } -func (router *RandomRouter) addConn(be *list.Element, conn *ConnWrapper) { - backend := be.Value.(*BackendWrapper) +func (router *ScoreBasedRouter) addConn(be *list.Element, conn *connWrapper) { + backend := be.Value.(*backendWrapper) ce := backend.connList.PushBack(conn) backend.connMap[conn.ConnectionID()] = ce router.adjustBackendList(be) } -func (router *RandomRouter) adjustBackendList(be *list.Element) { - backend := be.Value.(*BackendWrapper) +// adjustBackendList moves `be` after the score of `be` changes to keep the list ordered. +func (router *ScoreBasedRouter) adjustBackendList(be *list.Element) { + backend := be.Value.(*backendWrapper) curScore := backend.score() var mark *list.Element for ele := be.Prev(); ele != nil; ele = ele.Prev() { - b := ele.Value.(*BackendWrapper) + b := ele.Value.(*backendWrapper) if b.score() >= curScore { break } @@ -156,7 +169,7 @@ func (router *RandomRouter) adjustBackendList(be *list.Element) { return } for ele := be.Next(); ele != nil; ele = ele.Next() { - b := ele.Value.(*BackendWrapper) + b := ele.Value.(*backendWrapper) if b.score() <= curScore { break } @@ -167,14 +180,16 @@ func (router *RandomRouter) adjustBackendList(be *list.Element) { } } -func (router *RandomRouter) RedirectConnections() error { +// RedirectConnections implements Router.RedirectConnections interface. +// It redirects all connections compulsively. It's only used for testing. +func (router *ScoreBasedRouter) RedirectConnections() error { router.Lock() defer router.Unlock() for be := router.backends.Front(); be != nil; be = be.Next() { - backend := be.Value.(*BackendWrapper) + backend := be.Value.(*backendWrapper) for ce := backend.connList.Front(); ce != nil; ce = ce.Next() { // This is only for test, so we allow it to reconnect to the same backend. - connWrapper := ce.Value.(*ConnWrapper) + connWrapper := ce.Value.(*connWrapper) if connWrapper.phase != phaseRedirectNotify { connWrapper.phase = phaseRedirectNotify connWrapper.Redirect(backend.addr) @@ -184,17 +199,18 @@ func (router *RandomRouter) RedirectConnections() error { return nil } -func (router *RandomRouter) lookupBackend(addr string, forward bool) *list.Element { +// forward is a hint to speed up searching. +func (router *ScoreBasedRouter) lookupBackend(addr string, forward bool) *list.Element { if forward { for be := router.backends.Front(); be != nil; be = be.Next() { - backend := be.Value.(*BackendWrapper) + backend := be.Value.(*backendWrapper) if backend.addr == addr { return be } } } else { for be := router.backends.Back(); be != nil; be = be.Prev() { - backend := be.Value.(*BackendWrapper) + backend := be.Value.(*backendWrapper) if backend.addr == addr { return be } @@ -203,36 +219,40 @@ func (router *RandomRouter) lookupBackend(addr string, forward bool) *list.Eleme return nil } -func (router *RandomRouter) OnRedirectSucceed(from, to string, conn RedirectableConn) { +// OnRedirectSucceed implements ConnEventReceiver.OnRedirectSucceed interface. +func (router *ScoreBasedRouter) OnRedirectSucceed(from, to string, conn RedirectableConn) { router.Lock() defer router.Unlock() be := router.lookupBackend(to, false) if be == nil { - // impossible here + logutil.BgLogger().Error("backend not found in the backend", zap.String("addr", to)) return } - toBackend := be.Value.(*BackendWrapper) + toBackend := be.Value.(*backendWrapper) e, ok := toBackend.connMap[conn.ConnectionID()] if !ok { - // impossible here + logutil.BgLogger().Error("connection not found in the backend", zap.String("addr", to), + zap.Uint64("conn", conn.ConnectionID())) return } - connWrapper := e.Value.(*ConnWrapper) + connWrapper := e.Value.(*connWrapper) connWrapper.phase = phaseRedirectEnd } -func (router *RandomRouter) OnRedirectFail(from, to string, conn RedirectableConn) { +// OnRedirectFail implements ConnEventReceiver.OnRedirectFail interface. +func (router *ScoreBasedRouter) OnRedirectFail(from, to string, conn RedirectableConn) { router.Lock() defer router.Unlock() be := router.lookupBackend(to, false) if be == nil { - // impossible here + logutil.BgLogger().Error("backend not found in the backend", zap.String("addr", to)) return } - toBackend := be.Value.(*BackendWrapper) + toBackend := be.Value.(*backendWrapper) ce, ok := toBackend.connMap[conn.ConnectionID()] if !ok { - // impossible here + logutil.BgLogger().Error("connection not found in the backend", zap.String("addr", to), + zap.Uint64("conn", conn.ConnectionID())) return } router.removeConn(be, ce) @@ -242,31 +262,39 @@ func (router *RandomRouter) OnRedirectFail(from, to string, conn RedirectableCon if be == nil { return } - connWrapper := ce.Value.(*ConnWrapper) + connWrapper := ce.Value.(*connWrapper) connWrapper.phase = phaseRedirectFail router.addConn(be, connWrapper) } -func (router *RandomRouter) OnConnClosed(addr string, conn RedirectableConn) { - connID := conn.ConnectionID() +// OnConnClosed implements ConnEventReceiver.OnConnClosed interface. +func (router *ScoreBasedRouter) OnConnClosed(addr string, conn RedirectableConn) { router.Lock() defer router.Unlock() + // Get the redirecting address in the lock, rather than letting the connection pass it in. + // While the connection closes, the router may also send a new redirection signal concurrently + // and move it to another backendWrapper. + if toAddr := conn.GetRedirectingAddr(); len(toAddr) > 0 { + addr = toAddr + } be := router.lookupBackend(addr, true) - if be != nil { - // impossible here + if be == nil { + logutil.BgLogger().Error("backend not found in the router", zap.String("addr", addr)) return } - backend := be.Value.(*BackendWrapper) - ce, ok := backend.connMap[connID] + backend := be.Value.(*backendWrapper) + ce, ok := backend.connMap[conn.ConnectionID()] if !ok { - // impossible here + logutil.BgLogger().Error("connection not found in the backend", zap.String("addr", addr), + zap.Uint64("conn", conn.ConnectionID())) return } router.removeConn(be, ce) router.removeBackendIfEmpty(be) } -func (router *RandomRouter) OnBackendChanged(backends map[string]BackendStatus) { +// OnBackendChanged implements BackendEventReceiver.OnBackendChanged interface. +func (router *ScoreBasedRouter) OnBackendChanged(backends map[string]BackendStatus) { router.Lock() defer router.Unlock() for addr, status := range backends { @@ -274,14 +302,14 @@ func (router *RandomRouter) OnBackendChanged(backends map[string]BackendStatus) if be == nil { logutil.BgLogger().Info("find new backend", zap.String("url", addr), zap.String("status", status.String())) - be = router.backends.PushBack(&BackendWrapper{ + be = router.backends.PushBack(&backendWrapper{ status: status, addr: addr, connList: list.New(), connMap: make(map[uint64]*list.Element), }) } else { - backend := be.Value.(*BackendWrapper) + backend := be.Value.(*backendWrapper) logutil.BgLogger().Info("update backend", zap.String("url", addr), zap.String("prev_status", backend.status.String()), zap.String("cur_status", status.String())) backend.status = status @@ -291,7 +319,7 @@ func (router *RandomRouter) OnBackendChanged(backends map[string]BackendStatus) } } -func (router *RandomRouter) rebalanceLoop(ctx context.Context) { +func (router *ScoreBasedRouter) rebalanceLoop(ctx context.Context) { for { router.rebalance(rebalanceConnsPerLoop) select { @@ -302,13 +330,13 @@ func (router *RandomRouter) rebalanceLoop(ctx context.Context) { } } -func (router *RandomRouter) rebalance(maxNum int) { +func (router *ScoreBasedRouter) rebalance(maxNum int) { router.Lock() defer router.Unlock() for i := 0; i < maxNum; i++ { var busiestEle *list.Element for be := router.backends.Front(); be != nil; be = be.Next() { - backend := be.Value.(*BackendWrapper) + backend := be.Value.(*backendWrapper) if backend.connList.Len() > 0 { busiestEle = be break @@ -317,29 +345,30 @@ func (router *RandomRouter) rebalance(maxNum int) { if busiestEle == nil { break } - busiestBackend := busiestEle.Value.(*BackendWrapper) + busiestBackend := busiestEle.Value.(*backendWrapper) idlestEle := router.backends.Back() - idlestBackend := idlestEle.Value.(*BackendWrapper) + idlestBackend := idlestEle.Value.(*backendWrapper) if float64(busiestBackend.score())/float64(idlestBackend.score()+1) <= rebalanceMaxScoreRatio { break } ce := busiestBackend.connList.Front() router.removeConn(busiestEle, ce) - conn := ce.Value.(*ConnWrapper) + conn := ce.Value.(*connWrapper) conn.phase = phaseRedirectNotify router.addConn(idlestEle, conn) conn.Redirect(idlestBackend.addr) } } -func (router *RandomRouter) removeBackendIfEmpty(be *list.Element) { - backend := be.Value.(*BackendWrapper) +func (router *ScoreBasedRouter) removeBackendIfEmpty(be *list.Element) { + backend := be.Value.(*backendWrapper) if backend.status == StatusCannotConnect && backend.connList.Len() == 0 { router.backends.Remove(be) } } -func (router *RandomRouter) Close() { +// Close implements Router.Close interface. +func (router *ScoreBasedRouter) Close() { router.Lock() defer router.Unlock() if router.cancelFunc != nil { diff --git a/pkg/proxy/backend/backend_conn_mgr.go b/pkg/proxy/backend/backend_conn_mgr.go index 87710911..aeb63e2a 100644 --- a/pkg/proxy/backend/backend_conn_mgr.go +++ b/pkg/proxy/backend/backend_conn_mgr.go @@ -29,6 +29,7 @@ import ( "github.com/pingcap/TiProxy/pkg/manager/router" pnet "github.com/pingcap/TiProxy/pkg/proxy/net" "github.com/pingcap/TiProxy/pkg/util/errors" + "github.com/pingcap/TiProxy/pkg/util/waitgroup" "github.com/pingcap/tidb/parser/mysql" "github.com/pingcap/tidb/util/logutil" "go.uber.org/zap" @@ -46,6 +47,12 @@ type signalRedirect struct { newAddr string } +type redirectResult struct { + from string + to string + err error +} + // BackendConnManager migrates a session from one BackendConnection to another. // // The signal processing goroutine tries to migrate the session once it receives a signal. @@ -59,7 +66,7 @@ type BackendConnManager struct { connectionID uint64 authenticator *Authenticator cmdProcessor *CmdProcessor - eventReceiver router.ConnEventReceiver + eventReceiver unsafe.Pointer backendConn *BackendConnection // processLock makes redirecting and command processing exclusive. processLock sync.Mutex @@ -68,8 +75,11 @@ type BackendConnManager struct { // type *signalRedirect, it saves the last signal if there are multiple signals. // It will be set to nil after migration. signal unsafe.Pointer + // redirectResCh is used to notify the event receiver asynchronously. + redirectResCh chan *redirectResult // cancelFunc is used to cancel the signal processing goroutine. cancelFunc context.CancelFunc + wg waitgroup.WaitGroup } // NewBackendConnManager creates a BackendConnManager. @@ -79,6 +89,7 @@ func NewBackendConnManager(connectionID uint64) *BackendConnManager { cmdProcessor: NewCmdProcessor(), authenticator: &Authenticator{}, signalReceived: make(chan struct{}), + redirectResCh: make(chan *redirectResult, 1), } } @@ -102,8 +113,10 @@ func (mgr *BackendConnManager) Connect(ctx context.Context, serverAddr string, c } mgr.cmdProcessor.capability = mgr.authenticator.capability childCtx, cancelFunc := context.WithCancel(ctx) - go mgr.processSignals(childCtx) mgr.cancelFunc = cancelFunc + mgr.wg.Run(func() { + mgr.processSignals(childCtx) + }) return nil } @@ -141,7 +154,7 @@ func (mgr *BackendConnManager) ExecuteCmd(ctx context.Context, request []byte, c } // Even if it meets an MySQL error, it may have changed the status, such as when executing multi-statements. if waitingRedirect && mgr.cmdProcessor.canRedirect() { - _ = mgr.tryRedirect(ctx) + mgr.tryRedirect(ctx) // Execute the held request no matter redirection succeeds or not. if holdRequest { _, err = mgr.cmdProcessor.executeCmd(request, clientIO, mgr.backendConn.PacketIO(), false) @@ -157,7 +170,15 @@ func (mgr *BackendConnManager) ExecuteCmd(ctx context.Context, request []byte, c // SetEventReceiver implements RedirectableConn.SetEventReceiver interface. // The receiver sends redirection signals and watches redirecting events. func (mgr *BackendConnManager) SetEventReceiver(receiver router.ConnEventReceiver) { - mgr.eventReceiver = receiver + atomic.StorePointer(&mgr.eventReceiver, unsafe.Pointer(&receiver)) +} + +func (mgr *BackendConnManager) getEventReceiver() router.ConnEventReceiver { + eventReceiver := (*router.ConnEventReceiver)(atomic.LoadPointer(&mgr.eventReceiver)) + if eventReceiver == nil { + return nil + } + return *eventReceiver } func (mgr *BackendConnManager) initSessionStates(backendIO *pnet.PacketIO, sessionStates string) error { @@ -182,19 +203,19 @@ func (mgr *BackendConnManager) querySessionStates() (sessionStates, sessionToken return } -// processSignals runs in a goroutine to receive redirection signals. -// It will then try to migrate the session. +// processSignals runs in a goroutine to: +// - Receive redirection signals and then try to migrate the session. +// - Send redirection results to the event receiver. func (mgr *BackendConnManager) processSignals(ctx context.Context) { for { select { - // Redirect the session immediately just in case the session is idle. - case _, ok := <-mgr.signalReceived: - if !ok { - return - } + case <-mgr.signalReceived: + // Redirect the session immediately just in case the session is idle. mgr.processLock.Lock() - _ = mgr.tryRedirect(ctx) + mgr.tryRedirect(ctx) mgr.processLock.Unlock() + case rs := <-mgr.redirectResCh: + mgr.notifyRedirectResult(ctx, rs) case <-ctx.Done(): return } @@ -203,54 +224,49 @@ func (mgr *BackendConnManager) processSignals(ctx context.Context) { // tryRedirect tries to migrate the session if the session is redirect-able. // NOTE: processLock should be held before calling this function. -func (mgr *BackendConnManager) tryRedirect(ctx context.Context) error { +func (mgr *BackendConnManager) tryRedirect(ctx context.Context) { signal := (*signalRedirect)(atomic.LoadPointer(&mgr.signal)) if signal == nil { - return nil + return } - if !mgr.cmdProcessor.canRedirect() { - return nil + return } - from := mgr.backendConn.Addr() - to := signal.newAddr - var err error + rs := &redirectResult{ + from: mgr.backendConn.Addr(), + to: signal.newAddr, + } defer func() { - if err != nil { - mgr.eventReceiver.OnRedirectFail(from, to, mgr) - logutil.Logger(ctx).Warn("redirect connection failed", zap.String("from", from), zap.String("to", to), zap.Error(err)) - } else { - mgr.eventReceiver.OnRedirectSucceed(from, to, mgr) - logutil.Logger(ctx).Info("redirect connection succeeds", zap.String("from", from), zap.String("to", to)) - } + // The `mgr` won't be notified again before it calls `OnRedirectSucceed`, so simply `StorePointer` is also fine. + atomic.CompareAndSwapPointer(&mgr.signal, unsafe.Pointer(signal), nil) + // Notifying may block. Notify the receiver asynchronously to: + // - Reduce the latency of session migration + // - Avoid the risk of deadlock + mgr.redirectResCh <- rs }() - var sessionStates, sessionToken string - if sessionStates, sessionToken, err = mgr.querySessionStates(); err != nil { - return err + if sessionStates, sessionToken, rs.err = mgr.querySessionStates(); rs.err != nil { + return } - newConn := NewBackendConnection(to) - if err = newConn.Connect(); err != nil { - return err + newConn := NewBackendConnection(rs.to) + if rs.err = newConn.Connect(); rs.err != nil { + return } - if err = mgr.authenticator.handshakeSecondTime(newConn.PacketIO(), sessionToken); err == nil { - err = mgr.initSessionStates(newConn.PacketIO(), sessionStates) + if rs.err = mgr.authenticator.handshakeSecondTime(newConn.PacketIO(), sessionToken); rs.err == nil { + rs.err = mgr.initSessionStates(newConn.PacketIO(), sessionStates) } - if err != nil { + if rs.err != nil { if ignoredErr := newConn.Close(); ignoredErr != nil { logutil.Logger(ctx).Warn("close new backend connection failed", zap.Error(ignoredErr)) } - return err + return } if ignoredErr := mgr.backendConn.Close(); ignoredErr != nil { logutil.Logger(ctx).Warn("close previous backend connection failed", zap.Error(ignoredErr)) } mgr.backendConn = newConn - // The `mgr` won't be notified again before it calls `OnRedirectSucceed`, so simply `StorePointer` is also fine. - atomic.CompareAndSwapPointer(&mgr.signal, unsafe.Pointer(signal), nil) - return nil } // The original db in the auth info may be dropped during the session, so we need to authenticate with the current db. @@ -272,37 +288,73 @@ func (mgr *BackendConnManager) updateAuthInfoFromSessionStates(sessionStates []b func (mgr *BackendConnManager) Redirect(newAddr string) { // We do not use `chan signalRedirect` to avoid blocking. We cannot discard the signal when it blocks, // because only the latest signal matters. + // NOTE: BackendConnManager may be closing concurrently because of no lock. atomic.StorePointer(&mgr.signal, unsafe.Pointer(&signalRedirect{ newAddr: newAddr, })) - logutil.BgLogger().Info("received redirect command", zap.String("from", mgr.backendConn.Addr()), zap.String("to", newAddr)) select { case mgr.signalReceived <- struct{}{}: default: } } +// GetRedirectingAddr implements RedirectableConn.GetRedirectingAddr interface. +// It returns the goal backend address to redirect to. +func (mgr *BackendConnManager) GetRedirectingAddr() string { + signal := (*signalRedirect)(atomic.LoadPointer(&mgr.signal)) + if signal == nil { + return "" + } + return signal.newAddr +} + +func (mgr *BackendConnManager) notifyRedirectResult(ctx context.Context, rs *redirectResult) { + if rs == nil { + return + } + eventReceiver := mgr.getEventReceiver() + if eventReceiver == nil { + return + } + if rs.err != nil { + eventReceiver.OnRedirectFail(rs.from, rs.to, mgr) + logutil.Logger(ctx).Warn("redirect connection failed", zap.String("from", rs.from), + zap.String("to", rs.to), zap.Uint64("conn", mgr.connectionID), zap.Error(rs.err)) + } else { + eventReceiver.OnRedirectSucceed(rs.from, rs.to, mgr) + logutil.Logger(ctx).Info("redirect connection succeeds", zap.String("from", rs.from), + zap.String("to", rs.to), zap.Uint64("conn", mgr.connectionID)) + } +} + // Close releases all resources. func (mgr *BackendConnManager) Close() error { if mgr.cancelFunc != nil { mgr.cancelFunc() mgr.cancelFunc = nil } - mgr.processLock.Lock() - defer mgr.processLock.Unlock() + mgr.wg.Wait() + var err error - if mgr.eventReceiver != nil { - // Always notify the eventReceiver with the latest address. - signal := (*signalRedirect)(atomic.LoadPointer(&mgr.signal)) - if signal != nil { - mgr.eventReceiver.OnConnClosed(signal.newAddr, mgr) - } else if mgr.backendConn != nil { - mgr.eventReceiver.OnConnClosed(mgr.backendConn.Addr(), mgr) - } - } + var addr string + mgr.processLock.Lock() if mgr.backendConn != nil { + addr = mgr.backendConn.address err = mgr.backendConn.Close() mgr.backendConn = nil } + mgr.processLock.Unlock() + + eventReceiver := mgr.getEventReceiver() + if eventReceiver != nil { + // Notify the receiver if there's any event. + if len(mgr.redirectResCh) > 0 { + mgr.notifyRedirectResult(context.Background(), <-mgr.redirectResCh) + } + // Just notify it with the current address. + if len(addr) > 0 { + eventReceiver.OnConnClosed(addr, mgr) + } + } return err } diff --git a/pkg/proxy/backend/backend_conn_mgr_test.go b/pkg/proxy/backend/backend_conn_mgr_test.go index 382619fb..9c9e9811 100644 --- a/pkg/proxy/backend/backend_conn_mgr_test.go +++ b/pkg/proxy/backend/backend_conn_mgr_test.go @@ -21,6 +21,7 @@ import ( "github.com/pingcap/TiProxy/pkg/manager/router" pnet "github.com/pingcap/TiProxy/pkg/proxy/net" + "github.com/pingcap/TiProxy/pkg/util/waitgroup" "github.com/pingcap/tidb/parser/mysql" "github.com/stretchr/testify/require" ) @@ -83,7 +84,8 @@ type runner struct { // backendMgrTester encapsulates testSuite but is dedicated for BackendConnMgr. type backendMgrTester struct { *testSuite - t *testing.T + t *testing.T + closed bool } func newBackendMgrTester(t *testing.T) *backendMgrTester { @@ -92,18 +94,23 @@ func newBackendMgrTester(t *testing.T) *backendMgrTester { cfg.testSuiteConfig.initBackendConn = false } ts, clean := newTestSuite(t, tc, cfg) + tester := &backendMgrTester{ + testSuite: ts, + t: t, + } t.Cleanup(func() { clean() + if tester.closed { + return + } err := ts.mp.Close() require.NoError(t, err) - if ts.mp.eventReceiver != nil { - ts.mp.eventReceiver.(*mockEventReceiver).checkEvent(t, eventClose) + eventReceiver := ts.mp.getEventReceiver() + if eventReceiver != nil { + eventReceiver.(*mockEventReceiver).checkEvent(t, eventClose) } }) - return &backendMgrTester{ - testSuite: ts, - t: t, - } + return tester } // Define some common runners here to reduce code redundancy. @@ -175,16 +182,18 @@ func (ts *backendMgrTester) redirectAfterCmd4Proxy(clientIO, backendIO *pnet.Pac backend1 := ts.mp.backendConn err := ts.forwardCmd4Proxy(clientIO, backendIO) require.NoError(ts.t, err) - ts.mp.eventReceiver.(*mockEventReceiver).checkEvent(ts.t, eventSucceed) + ts.mp.getEventReceiver().(*mockEventReceiver).checkEvent(ts.t, eventSucceed) require.NotEqual(ts.t, backend1, ts.mp.backendConn) + require.Len(ts.t, ts.mp.GetRedirectingAddr(), 0) return nil } func (ts *backendMgrTester) redirectFail4Proxy(clientIO, backendIO *pnet.PacketIO) error { backend1 := ts.mp.backendConn ts.mp.Redirect(ts.tc.backendListener.Addr().String()) - ts.mp.eventReceiver.(*mockEventReceiver).checkEvent(ts.t, eventFail) + ts.mp.getEventReceiver().(*mockEventReceiver).checkEvent(ts.t, eventFail) require.Equal(ts.t, backend1, ts.mp.backendConn) + require.Len(ts.t, ts.mp.GetRedirectingAddr(), 0) return nil } @@ -210,7 +219,7 @@ func TestNormalRedirect(t *testing.T) { proxy: func(clientIO, backendIO *pnet.PacketIO) error { backend1 := ts.mp.backendConn ts.mp.Redirect(ts.tc.backendListener.Addr().String()) - ts.mp.eventReceiver.(*mockEventReceiver).checkEvent(t, eventSucceed) + ts.mp.getEventReceiver().(*mockEventReceiver).checkEvent(t, eventSucceed) require.NotEqual(t, backend1, ts.mp.backendConn) return nil }, @@ -314,7 +323,7 @@ func TestRedirectInTxn(t *testing.T) { backend1 := ts.mp.backendConn err := ts.forwardCmd4Proxy(clientIO, backendIO) require.NoError(t, err) - ts.mp.eventReceiver.(*mockEventReceiver).checkEvent(t, eventFail) + ts.mp.getEventReceiver().(*mockEventReceiver).checkEvent(t, eventFail) require.Equal(t, backend1, ts.mp.backendConn) return nil }, @@ -461,7 +470,7 @@ func TestSpecialCmds(t *testing.T) { proxy: func(clientIO, backendIO *pnet.PacketIO) error { backend1 := ts.mp.backendConn ts.mp.Redirect(ts.tc.backendListener.Addr().String()) - ts.mp.eventReceiver.(*mockEventReceiver).checkEvent(t, eventSucceed) + ts.mp.getEventReceiver().(*mockEventReceiver).checkEvent(t, eventSucceed) require.NotEqual(t, backend1, ts.mp.backendConn) return nil }, @@ -476,3 +485,40 @@ func TestSpecialCmds(t *testing.T) { } ts.runTests(runners) } + +// Test that closing the BackendConnMgr while it's receiving a redirection signal is OK. +func TestCloseWhileRedirect(t *testing.T) { + ts := newBackendMgrTester(t) + runners := []runner{ + // 1st handshake + { + client: ts.mc.authenticate, + proxy: ts.firstHandshake4Proxy, + backend: ts.handshake4Backend, + }, + // close and redirect concurrently + { + proxy: func(clientIO, backendIO *pnet.PacketIO) error { + // Send an event to make Close() block at notifying. + addr := ts.tc.backendListener.Addr().String() + eventReceiver := ts.mp.getEventReceiver().(*mockEventReceiver) + eventReceiver.OnRedirectSucceed(addr, addr, ts.mp) + var wg waitgroup.WaitGroup + wg.Run(func() { + _ = ts.mp.Close() + ts.closed = true + }) + // Make sure the process goroutine finishes. + ts.mp.wg.Wait() + // Redirect() should not panic after Close(). + ts.mp.Redirect(addr) + eventReceiver.checkEvent(t, eventSucceed) + require.Equal(t, addr, ts.mp.GetRedirectingAddr()) + wg.Wait() + eventReceiver.checkEvent(t, eventClose) + return nil + }, + }, + } + ts.runTests(runners) +}