diff --git a/server/conn.go b/server/conn.go index 44f43920a8b37..4f93694c680fb 100644 --- a/server/conn.go +++ b/server/conn.go @@ -1163,6 +1163,13 @@ func (cc *clientConn) Run(ctx context.Context) { return } + // Should check InTxn() to avoid execute `begin` stmt. + if cc.server.inShutdownMode.Load() { + if !cc.ctx.GetSessionVars().InTxn() { + return + } + } + if !atomic.CompareAndSwapInt32(&cc.status, connStatusReading, connStatusDispatching) { return } diff --git a/server/conn_test.go b/server/conn_test.go index c540b1784793d..a885a0336f765 100644 --- a/server/conn_test.go +++ b/server/conn_test.go @@ -764,7 +764,7 @@ func TestConnExecutionTimeout(t *testing.T) { } func TestShutDown(t *testing.T) { - store := testkit.CreateMockStore(t) + store, dom := testkit.CreateMockStoreAndDomain(t) cc := &clientConn{} se, err := session.CreateSession4Test(store) @@ -776,6 +776,34 @@ func TestShutDown(t *testing.T) { // assert ErrQueryInterrupted err = cc.handleQuery(context.Background(), "select 1") require.Equal(t, executor.ErrQueryInterrupted, err) + + cfg := newTestConfig() + cfg.Port = 0 + cfg.Status.StatusPort = 0 + drv := NewTiDBDriver(store) + srv, err := NewServer(cfg, drv) + require.NoError(t, err) + srv.SetDomain(dom) + + cc = &clientConn{server: srv} + cc.setCtx(tc) + + // test in txn + srv.clients[cc.connectionID+1] = cc + cc.getCtx().GetSessionVars().SetInTxn(true) + + waitTime := 100 * time.Millisecond + begin := time.Now() + srv.DrainClients(waitTime, waitTime) + require.Greater(t, time.Since(begin), waitTime) + + // test not in txn + srv.clients[cc.connectionID+2] = cc + cc.getCtx().GetSessionVars().SetInTxn(false) + + begin = time.Now() + srv.DrainClients(waitTime, waitTime) + require.Less(t, time.Since(begin), waitTime) } type snapshotCache interface { diff --git a/server/server.go b/server/server.go index 8bb184261b046..0d9c9fa9096f0 100644 --- a/server/server.go +++ b/server/server.go @@ -34,7 +34,6 @@ import ( "crypto/tls" "fmt" "io" - "math/rand" "net" "net/http" //nolint:goimports // For pprof @@ -334,9 +333,6 @@ func NewServer(cfg *config.Config, driver IDriver) (*Server, error) { } } - // Init rand seed for randomBuf() - rand.Seed(time.Now().UTC().UnixNano()) - variable.RegisterStatistics(s) return s, nil @@ -556,8 +552,6 @@ func (s *Server) closeListener() { metrics.ServerEventCounter.WithLabelValues(metrics.EventClose).Inc() } -var gracefulCloseConnectionsTimeout = 15 * time.Second - // Close closes the server. func (s *Server) Close() { s.startShutdown() @@ -891,6 +885,9 @@ func (s *Server) DrainClients(drainWait time.Duration, cancelWait time.Duration) go func() { defer close(allDone) for _, conn := range conns { + if !conn.getCtx().GetSessionVars().InTxn() { + continue + } select { case <-conn.quit: case <-quitWaitingForConns: diff --git a/tidb-server/main.go b/tidb-server/main.go index d70efc5a9e570..52d2268fc09de 100644 --- a/tidb-server/main.go +++ b/tidb-server/main.go @@ -224,9 +224,9 @@ func main() { terror.RegisterFinish() exited := make(chan struct{}) - signal.SetupSignalHandler(func(graceful bool) { + signal.SetupSignalHandler(func() { svr.Close() - cleanup(svr, storage, dom, graceful) + cleanup(svr, storage, dom) cpuprofile.StopCPUProfiler() close(exited) }) @@ -839,7 +839,7 @@ func closeDomainAndStorage(storage kv.Storage, dom *domain.Domain) { // We should better provider a dynamic way to set this value. var gracefulCloseConnectionsTimeout = 15 * time.Second -func cleanup(svr *server.Server, storage kv.Storage, dom *domain.Domain, _ bool) { +func cleanup(svr *server.Server, storage kv.Storage, dom *domain.Domain) { dom.StopAutoAnalyze() drainClientWait := gracefulCloseConnectionsTimeout diff --git a/util/signal/signal_posix.go b/util/signal/signal_posix.go index 8e9e25308c0ef..8b416b93c543e 100644 --- a/util/signal/signal_posix.go +++ b/util/signal/signal_posix.go @@ -27,7 +27,7 @@ import ( ) // SetupSignalHandler setup signal handler for TiDB Server -func SetupSignalHandler(shutdownFunc func(bool)) { +func SetupSignalHandler(shutdownFunc func()) { usrDefSignalChan := make(chan os.Signal, 1) signal.Notify(usrDefSignalChan, syscall.SIGUSR1) @@ -52,6 +52,6 @@ func SetupSignalHandler(shutdownFunc func(bool)) { go func() { sig := <-closeSignalChan logutil.BgLogger().Info("got signal to exit", zap.Stringer("signal", sig)) - shutdownFunc(sig != syscall.SIGHUP) + shutdownFunc() }() }