diff --git a/server/BUILD.bazel b/server/BUILD.bazel
index e346e35c466b1..79deb82037ab3 100644
--- a/server/BUILD.bazel
+++ b/server/BUILD.bazel
@@ -116,6 +116,7 @@ go_library(
         "@org_golang_google_grpc//channelz/service",
         "@org_golang_google_grpc//keepalive",
         "@org_golang_google_grpc//peer",
+        "@org_uber_go_atomic//:atomic",
         "@org_uber_go_zap//:zap",
     ],
 )
diff --git a/server/conn.go b/server/conn.go
index ecd5977f0d101..41133a9bec358 100644
--- a/server/conn.go
+++ b/server/conn.go
@@ -172,6 +172,7 @@ func newClientConn(s *Server) *clientConn {
 		status:       connStatusDispatching,
 		lastActive:   time.Now(),
 		authPlugin:   mysql.AuthNativePassword,
+		quit:         make(chan struct{}),
 		ppEnabled:    s.cfg.ProxyProtocol.Networks != "",
 	}
 }
@@ -215,6 +216,8 @@ type clientConn struct {
 		sync.RWMutex
 		cancelFunc context.CancelFunc
 	}
+	// quit is closed once clientConn quits Run().
+	quit chan struct{}
 	extensions *extension.SessionExtensions

 	// Proxy Protocol Enabled
@@ -1093,6 +1096,12 @@ func (cc *clientConn) Run(ctx context.Context) {
 			terror.Log(err)
 			metrics.PanicCounter.WithLabelValues(metrics.LabelSession).Inc()
 		}
+		if atomic.LoadInt32(&cc.status) != connStatusShutdown {
+			err := cc.Close()
+			terror.Log(err)
+		}
+
+		close(cc.quit)
 	}()

 	// Usually, client connection status changes between [dispatching] <=> [reading].
@@ -1101,6 +1110,13 @@ func (cc *clientConn) Run(ctx context.Context) {
 	// The client connection would detect the events when it fails to change status
 	// by CAS operation, it would then take some actions accordingly.
 	for {
+		// Close the connection between transactions when the server is shutting down.
+		if cc.server.inShutdownMode.Load() {
+			if !cc.ctx.GetSessionVars().InTxn() {
+				return
+			}
+		}
+
 		if !atomic.CompareAndSwapInt32(&cc.status, connStatusDispatching, connStatusReading) ||
 			// The judge below will not be hit by all means,
 			// But keep it stayed as a reminder and for the code reference for connStatusWaitShutdown.
@@ -1110,6 +1126,7 @@ func (cc *clientConn) Run(ctx context.Context) {
 		cc.alloc.Reset()
 		// close connection when idle time is more than wait_timeout
+		// default 28800 (8h). FIXME: we should not block here when the connection is being killed.
 		waitTimeout := cc.getSessionVarsWaitTimeout(ctx)
 		cc.pkt.setReadTimeout(time.Duration(waitTimeout) * time.Second)
 		start := time.Now()
@@ -1196,22 +1213,6 @@ func (cc *clientConn) Run(ctx context.Context) {
 	}
 }

-// ShutdownOrNotify will Shutdown this client connection, or do its best to notify.
-func (cc *clientConn) ShutdownOrNotify() bool {
-	if (cc.ctx.Status() & mysql.ServerStatusInTrans) > 0 {
-		return false
-	}
-	// If the client connection status is reading, it's safe to shutdown it.
-	if atomic.CompareAndSwapInt32(&cc.status, connStatusReading, connStatusShutdown) {
-		return true
-	}
-	// If the client connection status is dispatching, we can't shutdown it immediately,
-	// so set the status to WaitShutdown as a notification, the loop in clientConn.Run
-	// will detect it and then exit.
-	atomic.StoreInt32(&cc.status, connStatusWaitShutdown)
-	return false
-}
-
 func errStrForLog(err error, enableRedactLog bool) string {
 	if enableRedactLog {
 		// currently, only ErrParse is considered when enableRedactLog because it may contain sensitive information like
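The conn.go changes above drop the ShutdownOrNotify handshake in favour of a quit channel that Run closes in its deferred cleanup, plus an in-loop check of inShutdownMode between transactions. A minimal, self-contained sketch of that pattern follows (the worker/drain names are hypothetical, not TiDB code):

// A minimal sketch (not TiDB code) of the lifecycle pattern the conn.go hunks
// introduce: the connection loop closes a `quit` channel in its deferred
// cleanup, and a shutdown path can wait on that channel with a timeout.
package main

import (
	"fmt"
	"time"
)

type worker struct {
	quit chan struct{} // closed exactly once when run() returns
}

func (w *worker) run(stop <-chan struct{}) {
	defer close(w.quit) // mirrors close(cc.quit) in clientConn.Run's defer
	for {
		select {
		case <-stop: // analogous to observing inShutdownMode between transactions
			return
		case <-time.After(10 * time.Millisecond): // pretend to serve one request
		}
	}
}

func drain(w *worker, wait time.Duration) bool {
	select {
	case <-w.quit:
		return true // the loop exited on its own
	case <-time.After(wait):
		return false // timed out; the caller may escalate, e.g. kill the connection
	}
}

func main() {
	w := &worker{quit: make(chan struct{})}
	stop := make(chan struct{})
	go w.run(stop)

	close(stop) // ask the loop to exit between "transactions"
	fmt.Println("drained:", drain(w, time.Second))
}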
diff --git a/server/conn_test.go b/server/conn_test.go
index fa3b9d5317a96..c540b1784793d 100644
--- a/server/conn_test.go
+++ b/server/conn_test.go
@@ -778,31 +778,6 @@ func TestShutDown(t *testing.T) {
 	require.Equal(t, executor.ErrQueryInterrupted, err)
 }

-func TestShutdownOrNotify(t *testing.T) {
-	store := testkit.CreateMockStore(t)
-	se, err := session.CreateSession4Test(store)
-	require.NoError(t, err)
-	tc := &TiDBContext{
-		Session: se,
-		stmts:   make(map[int]*TiDBStatement),
-	}
-	cc := &clientConn{
-		connectionID: 1,
-		server: &Server{
-			capability: defaultCapability,
-		},
-		status: connStatusWaitShutdown,
-	}
-	cc.setCtx(tc)
-	require.False(t, cc.ShutdownOrNotify())
-	cc.status = connStatusReading
-	require.True(t, cc.ShutdownOrNotify())
-	require.Equal(t, connStatusShutdown, cc.status)
-	cc.status = connStatusDispatching
-	require.False(t, cc.ShutdownOrNotify())
-	require.Equal(t, connStatusWaitShutdown, cc.status)
-}
-
 type snapshotCache interface {
 	SnapCacheHitCount() int
 }
diff --git a/server/http_status.go b/server/http_status.go
index 8070bd91e2b99..20aa534a7827f 100644
--- a/server/http_status.go
+++ b/server/http_status.go
@@ -539,7 +539,7 @@ func (s *Server) handleStatus(w http.ResponseWriter, req *http.Request) {
 	// If the server is in the process of shutting down, return a non-200 status.
 	// It is important not to return status{} as acquiring the s.ConnectionCount()
 	// acquires a lock that may already be held by the shutdown process.
-	if s.inShutdownMode {
+	if !s.health.Load() {
 		w.WriteHeader(http.StatusInternalServerError)
 		return
 	}
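The status handler above now reads an atomic health flag instead of the old rwlock-guarded inShutdownMode bool, so reporting unhealthy never has to touch the server lock. The sketch below illustrates the two-flag idea with the go.uber.org/atomic Bool API; the handler itself is illustrative, not TiDB's:

// Hedged sketch of the two-flag idea in the http_status.go/server.go hunks:
// `health` flips false early so the status endpoint reports unhealthy while the
// listener still accepts traffic, and `inShutdownMode` flips true later to stop
// new connections. Only the go.uber.org/atomic Bool API is assumed here.
package main

import (
	"net/http"

	uatomic "go.uber.org/atomic"
)

type statusServer struct {
	health         *uatomic.Bool // false => report 500 so load balancers drain us
	inShutdownMode *uatomic.Bool // true => reject brand-new connections
}

func newStatusServer() *statusServer {
	return &statusServer{
		health:         uatomic.NewBool(true),
		inShutdownMode: uatomic.NewBool(false),
	}
}

func (s *statusServer) handleStatus(w http.ResponseWriter, _ *http.Request) {
	// No lock is taken here, so the handler cannot deadlock with a shutdown
	// path that holds the server's rwlock.
	if !s.health.Load() {
		w.WriteHeader(http.StatusInternalServerError)
		return
	}
	w.WriteHeader(http.StatusOK)
}

func main() {
	s := newStatusServer()
	http.HandleFunc("/status", s.handleStatus)
	// In a real shutdown: s.health.Store(false) first, wait for LB probes,
	// then s.inShutdownMode.Store(true).
	_ = http.ListenAndServe("127.0.0.1:10080", nil)
}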
diff --git a/server/server.go b/server/server.go
index 3ab08629b232a..8bb184261b046 100644
--- a/server/server.go
+++ b/server/server.go
@@ -72,6 +72,7 @@ import (
 	"github.com/pingcap/tidb/util/logutil"
 	"github.com/pingcap/tidb/util/sys/linux"
 	"github.com/pingcap/tidb/util/timeutil"
+	uatomic "go.uber.org/atomic"
 	"go.uber.org/zap"
 	"google.golang.org/grpc"
 )
@@ -129,18 +130,21 @@ type Server struct {
 	driver            IDriver
 	listener          net.Listener
 	socket            net.Listener
-	rwlock            sync.RWMutex
 	concurrentLimiter *TokenLimiter
-	clients           map[uint64]*clientConn
-	capability        uint32
-	dom               *domain.Domain
-	globalConnID      util.GlobalConnID
+
+	rwlock  sync.RWMutex
+	clients map[uint64]*clientConn
+
+	capability   uint32
+	dom          *domain.Domain
+	globalConnID util.GlobalConnID

 	statusAddr     string
 	statusListener net.Listener
 	statusServer   *http.Server
 	grpcServer     *grpc.Server
-	inShutdownMode bool
+	inShutdownMode *uatomic.Bool
+	health         *uatomic.Bool

 	sessionMapMutex  sync.Mutex
 	internalSessions map[interface{}]struct{}
@@ -209,6 +213,8 @@ func NewServer(cfg *config.Config, driver IDriver) (*Server, error) {
 		globalConnID:      util.NewGlobalConnID(0, true),
 		internalSessions:  make(map[interface{}]struct{}, 100),
 		printMDLLogTime:   time.Now(),
+		health:            uatomic.NewBool(true),
+		inShutdownMode:    uatomic.NewBool(false),
 	}
 	s.capability = defaultCapability
 	setTxnScope()
@@ -396,7 +402,7 @@ func (s *Server) Run() error {
 	}
 	// If error should be reported and exit the server it can be sent on this
 	// channel. Otherwise, end with sending a nil error to signal "done"
-	errChan := make(chan error)
+	errChan := make(chan error, 2)
 	go s.startNetworkListener(s.listener, false, errChan)
 	go s.startNetworkListener(s.socket, true, errChan)
 	err := <-errChan
@@ -416,7 +422,7 @@ func (s *Server) startNetworkListener(listener net.Listener, isUnixSocket bool,
 		if err != nil {
 			if opErr, ok := err.(*net.OpError); ok {
 				if opErr.Err.Error() == "use of closed network connection" {
-					if s.inShutdownMode {
+					if s.inShutdownMode.Load() {
 						errChan <- nil
 					} else {
 						errChan <- err
@@ -436,6 +442,8 @@ func (s *Server) startNetworkListener(listener net.Listener, isUnixSocket bool,
 			return
 		}

+		logutil.BgLogger().Debug("accept new connection success")
+
 		clientConn := s.newConn(conn)
 		if isUnixSocket {
 			var (
@@ -507,10 +515,8 @@ func (s *Server) checkAuditPlugin(clientConn *clientConn) error {
 }

 func (s *Server) startShutdown() {
-	s.rwlock.RLock()
 	logutil.BgLogger().Info("setting tidb-server to report unhealthy (shutting-down)")
-	s.inShutdownMode = true
-	s.rwlock.RUnlock()
+	s.health.Store(false)
 	// give the load balancer a chance to receive a few unhealthy health reports
 	// before acquiring the s.rwlock and blocking connections.
 	waitTime := time.Duration(s.cfg.GracefulWaitBeforeShutdown) * time.Second
@@ -520,12 +526,7 @@ func (s *Server) startShutdown() {
 	}
 }

-// Close closes the server.
-func (s *Server) Close() {
-	s.startShutdown()
-	s.rwlock.Lock() // prevent new connections
-	defer s.rwlock.Unlock()
-
+func (s *Server) closeListener() {
 	if s.listener != nil {
 		err := s.listener.Close()
 		terror.Log(errors.Trace(err))
@@ -555,6 +556,34 @@ func (s *Server) Close() {
 	metrics.ServerEventCounter.WithLabelValues(metrics.EventClose).Inc()
 }

+var gracefulCloseConnectionsTimeout = 15 * time.Second
+
+// Close closes the server.
+func (s *Server) Close() {
+	s.startShutdown()
+	s.rwlock.Lock() // prevent new connections
+	defer s.rwlock.Unlock()
+	s.inShutdownMode.Store(true)
+	s.closeListener()
+}
+
+func (s *Server) registerConn(conn *clientConn) bool {
+	s.rwlock.Lock()
+	defer s.rwlock.Unlock()
+	connections := len(s.clients)
+
+	logger := logutil.BgLogger()
+	if s.inShutdownMode.Load() {
+		logger.Info("close connection directly when shutting down")
+		terror.Log(closeConn(conn, connections))
+		return false
+	}
+	s.clients[conn.connectionID] = conn
+	connections = len(s.clients)
+	metrics.ConnGauge.Set(float64(connections))
+	return true
+}
+
 // onConn runs in its own goroutine, handles queries from this connection.
 func (s *Server) onConn(conn *clientConn) {
 	// init the connInfo
@@ -583,6 +612,7 @@ func (s *Server) onConn(conn *clientConn) {
 	}

 	ctx := logutil.WithConnID(context.Background(), conn.connectionID)
+
 	if err := conn.handshake(ctx); err != nil {
 		conn.onExtensionConnEvent(extension.ConnHandshakeRejected, err)
 		if plugin.IsEnable(plugin.Audit) && conn.getCtx() != nil {
@@ -624,11 +654,10 @@ func (s *Server) onConn(conn *clientConn) {
 		terror.Log(conn.Close())
 		logutil.Logger(ctx).Debug("connection closed")
 	}()
-	s.rwlock.Lock()
-	s.clients[conn.connectionID] = conn
-	connections := len(s.clients)
-	s.rwlock.Unlock()
-	metrics.ConnGauge.Set(float64(connections))
+
+	if !s.registerConn(conn) {
+		return
+	}

 	sessionVars := conn.ctx.GetSessionVars()
 	sessionVars.ConnectionInfo = conn.connectInfo()
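registerConn above folds the shutdown check and the client-map insert into one critical section, so a connection can no longer register after Close() has flipped inShutdownMode. A hypothetical, stripped-down version of that pattern:

// Sketch of the registration pattern behind registerConn: the shutdown flag and
// the client map are checked and updated under the same lock, so a connection
// can never slip into the map after shutdown has started. The registry type and
// names are hypothetical, not the Server struct itself.
package main

import (
	"fmt"
	"sync"
)

type registry struct {
	mu       sync.Mutex
	closed   bool
	sessions map[uint64]struct{}
}

func newRegistry() *registry {
	return &registry{sessions: make(map[uint64]struct{})}
}

// register mirrors Server.registerConn: refuse (and let the caller close the
// connection) when shutdown has started, otherwise record the session.
func (r *registry) register(id uint64) bool {
	r.mu.Lock()
	defer r.mu.Unlock()
	if r.closed {
		return false
	}
	r.sessions[id] = struct{}{}
	return true
}

// shutdown mirrors Server.Close: the flag is set while holding the same lock.
func (r *registry) shutdown() {
	r.mu.Lock()
	defer r.mu.Unlock()
	r.closed = true
}

func main() {
	r := newRegistry()
	fmt.Println(r.register(1)) // true: accepted before shutdown
	r.shutdown()
	fmt.Println(r.register(2)) // false: the caller should close the connection
}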
@@ -784,7 +813,7 @@ func (s *Server) Kill(connectionID uint64, query bool) {
 		// this, it will end the dispatch loop and exit.
 		atomic.StoreInt32(&conn.status, connStatusWaitShutdown)
 	}
-	killConn(conn)
+	killQuery(conn)
 }

 // UpdateTLSConfig implements the SessionManager interface.
@@ -796,7 +825,7 @@ func (s *Server) getTLSConfig() *tls.Config {
 	return (*tls.Config)(atomic.LoadPointer(&s.tlsConfig))
 }

-func killConn(conn *clientConn) {
+func killQuery(conn *clientConn) {
 	sessVars := conn.ctx.GetSessionVars()
 	atomic.StoreUint32(&sessVars.Killed, 1)
 	conn.mu.RLock()
@@ -824,7 +853,8 @@ func (s *Server) KillSysProcesses() {
 	}
 }

-// KillAllConnections kills all connections when server is not gracefully shutdown.
+// KillAllConnections implements the SessionManager interface.
+// KillAllConnections kills all connections.
 func (s *Server) KillAllConnections() {
 	logutil.BgLogger().Info("[server] kill all connections.")

@@ -835,73 +865,53 @@ func (s *Server) KillAllConnections() {
 		if err := conn.closeWithoutLock(); err != nil {
 			terror.Log(err)
 		}
-		killConn(conn)
+		killQuery(conn)
 	}

 	s.KillSysProcesses()
 }

-var gracefulCloseConnectionsTimeout = 15 * time.Second
-
-// TryGracefulDown will try to gracefully close all connection first with timeout. if timeout, will close all connection directly.
-func (s *Server) TryGracefulDown() {
-	ctx, cancel := context.WithTimeout(context.Background(), gracefulCloseConnectionsTimeout)
-	defer cancel()
-	done := make(chan struct{})
-	go func() {
-		s.GracefulDown(ctx, done)
-	}()
-	select {
-	case <-ctx.Done():
-		s.KillAllConnections()
-	case <-done:
-		return
-	}
-}
-
-// GracefulDown waits all clients to close.
-func (s *Server) GracefulDown(ctx context.Context, done chan struct{}) {
-	logutil.Logger(ctx).Info("[server] graceful shutdown.")
-	metrics.ServerEventCounter.WithLabelValues(metrics.EventGracefulDown).Inc()
+// DrainClients drains all connections within drainWait.
+// After drainWait elapses, it kills the connections that have not quit yet and waits up to cancelWait for them.
+func (s *Server) DrainClients(drainWait time.Duration, cancelWait time.Duration) {
+	logger := logutil.BgLogger()
+	logger.Info("start drain clients")

-	count := s.ConnectionCount()
-	for i := 0; count > 0; i++ {
-		s.kickIdleConnection()
+	conns := make(map[uint64]*clientConn)

-		count = s.ConnectionCount()
-		if count == 0 {
-			break
-		}
-		// Print information for every 30s.
-		if i%30 == 0 {
-			logutil.Logger(ctx).Info("graceful shutdown...", zap.Int("conn count", count))
-		}
-		ticker := time.After(time.Second)
-		select {
-		case <-ctx.Done():
-			return
-		case <-ticker:
-		}
+	s.rwlock.Lock()
+	for k, v := range s.clients {
+		conns[k] = v
 	}
-	close(done)
-}
+	s.rwlock.Unlock()

-func (s *Server) kickIdleConnection() {
-	var conns []*clientConn
-	s.rwlock.RLock()
-	for _, cc := range s.clients {
-		if cc.ShutdownOrNotify() {
-			// Shutdowned conn will be closed by us, and notified conn will exist themselves.
-			conns = append(conns, cc)
+	allDone := make(chan struct{})
+	quitWaitingForConns := make(chan struct{})
+	defer close(quitWaitingForConns)
+	go func() {
+		defer close(allDone)
+		for _, conn := range conns {
+			select {
+			case <-conn.quit:
+			case <-quitWaitingForConns:
+				return
+			}
 		}
+	}()
+
+	select {
+	case <-allDone:
+		logger.Info("all sessions quit in drain wait time")
+	case <-time.After(drainWait):
+		logger.Info("timeout waiting all sessions quit")
 	}
-	s.rwlock.RUnlock()

-	for _, cc := range conns {
-		err := cc.Close()
-		if err != nil {
-			logutil.BgLogger().Error("close connection", zap.Error(err))
-		}
+	s.KillAllConnections()
+
+	select {
+	case <-allDone:
+	case <-time.After(cancelWait):
+		logger.Warn("some sessions do not quit in cancel wait time")
 	}
 }
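DrainClients above waits for every connection's quit channel through a single watcher goroutine, escalates with KillAllConnections after drainWait, and then allows cancelWait for the stragglers. A compact model of that control flow (only the drainWait/cancelWait names are taken from the patch; everything else is illustrative):

// A hypothetical model of the DrainClients control flow: a watcher goroutine
// waits on every session's quit channel and then closes allDone; the caller
// waits up to drainWait, escalates with kill(), and waits up to cancelWait.
package main

import (
	"fmt"
	"time"
)

func drain(quitChans []chan struct{}, drainWait, cancelWait time.Duration, kill func()) {
	allDone := make(chan struct{})
	quitWaiting := make(chan struct{})
	defer close(quitWaiting)

	go func() {
		defer close(allDone)
		for _, q := range quitChans {
			select {
			case <-q: // this session's Run loop has returned
			case <-quitWaiting: // drain returned early; stop watching
				return
			}
		}
	}()

	select {
	case <-allDone:
		fmt.Println("all sessions quit within drainWait")
	case <-time.After(drainWait):
		fmt.Println("drainWait elapsed; some sessions are still running")
	}

	// As in the patch, the kill step runs in both branches (in the real server
	// it also stops system sessions such as auto-analyze).
	kill()

	select {
	case <-allDone:
		fmt.Println("remaining sessions quit within cancelWait")
	case <-time.After(cancelWait):
		fmt.Println("some sessions did not quit within cancelWait")
	}
}

func main() {
	q := make(chan struct{})
	go func() {
		time.Sleep(50 * time.Millisecond) // pretend the session finishes its transaction
		close(q)
	}()
	drain([]chan struct{}{q}, time.Second, time.Second, func() {})
}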
diff --git a/tidb-server/main.go b/tidb-server/main.go
index d8b830fc9c9a1..0ea91bf3a3a79 100644
--- a/tidb-server/main.go
+++ b/tidb-server/main.go
@@ -835,17 +835,23 @@ func closeDomainAndStorage(storage kv.Storage, dom *domain.Domain) {
 	terror.Log(errors.Trace(err))
 }

+var gracefulCloseConnectionsTimeout = 15 * time.Second
+
 func cleanup(svr *server.Server, storage kv.Storage, dom *domain.Domain, graceful bool) {
 	dom.StopAutoAnalyze()
+
+	var drainClientWait time.Duration
 	if graceful {
-		done := make(chan struct{})
-		svr.GracefulDown(context.Background(), done)
-		// Kill sys processes such as auto analyze. Otherwise, tidb-server cannot exit until auto analyze is finished.
-		// See https://github.com/pingcap/tidb/issues/40038 for details.
-		svr.KillSysProcesses()
+		drainClientWait = 1<<63 - 1
 	} else {
-		svr.TryGracefulDown()
+		drainClientWait = gracefulCloseConnectionsTimeout
 	}
+	cancelClientWait := time.Second * 1
+	svr.DrainClients(drainClientWait, cancelClientWait)
+
+	// Kill sys processes such as auto analyze. Otherwise, tidb-server cannot exit until auto analyze is finished.
+	// See https://github.com/pingcap/tidb/issues/40038 for details.
+	svr.KillSysProcesses()
 	plugin.Shutdown(context.Background())
 	closeDomainAndStorage(storage, dom)
 	disk.CleanUp()
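In cleanup above, a graceful shutdown passes 1<<63 - 1 as drainWait, which is simply the largest possible time.Duration (about 292 years, i.e. effectively unbounded), while the forced path reuses the 15-second gracefulCloseConnectionsTimeout; both paths allow a one-second cancelWait. A small standalone check of those values (nothing here is server code, it only prints the constants):

// Quick check of the constants the new cleanup() path relies on.
package main

import (
	"fmt"
	"math"
	"time"
)

func main() {
	gracefulWait := time.Duration(1<<63 - 1) // same literal as in cleanup()
	forcedWait := 15 * time.Second           // gracefulCloseConnectionsTimeout
	cancelWait := time.Second * 1            // wait after KillAllConnections

	fmt.Println(gracefulWait == time.Duration(math.MaxInt64)) // true: it is the max duration
	fmt.Println(gracefulWait.Hours() / 24 / 365)              // roughly 292 years
	fmt.Println(forcedWait, cancelWait)
}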