diff --git a/pkg/proxy/backend/backend_conn_mgr.go b/pkg/proxy/backend/backend_conn_mgr.go index 2036a7d8..d295858c 100644 --- a/pkg/proxy/backend/backend_conn_mgr.go +++ b/pkg/proxy/backend/backend_conn_mgr.go @@ -132,6 +132,7 @@ type BackendConnManager struct { handshakeHandler HandshakeHandler ctxmap sync.Map connectionID uint64 + quitSource ErrorSource } // NewBackendConnManager creates a BackendConnManager. @@ -151,6 +152,7 @@ func NewBackendConnManager(logger *zap.Logger, handshakeHandler HandshakeHandler // There are 2 types of signals, which may be sent concurrently. signalReceived: make(chan signalType, signalTypeNums), redirectResCh: make(chan *redirectResult, 1), + quitSource: SrcClientQuit, } return mgr } @@ -170,11 +172,14 @@ func (mgr *BackendConnManager) Connect(ctx context.Context, clientIO *pnet.Packe mgr.clientIO = clientIO err := mgr.authenticator.handshakeFirstTime(mgr.logger.Named("authenticator"), mgr, clientIO, mgr.handshakeHandler, mgr.getBackendIO, frontendTLSConfig, backendTLSConfig) - mgr.handshakeHandler.OnHandshake(mgr, mgr.ServerAddr(), err) if err != nil { + mgr.setQuitSourceByErr(err) + mgr.handshakeHandler.OnHandshake(mgr, mgr.ServerAddr(), err) WriteUserError(clientIO, err, mgr.logger) return err } + mgr.resetQuitSource() + mgr.handshakeHandler.OnHandshake(mgr, mgr.ServerAddr(), nil) mgr.cmdProcessor.capability = mgr.authenticator.capability childCtx, cancelFunc := context.WithCancel(ctx) @@ -233,7 +238,7 @@ func (mgr *BackendConnManager) getBackendIO(cctx ConnContext, auth *Authenticato // NOTE: should use DNS name as much as possible // Usually certs are signed with domain instead of IP addrs // And `RemoteAddr()` will return IP addr - backendIO := pnet.NewPacketIO(cn, pnet.WithRemoteAddr(addr, cn.RemoteAddr())) + backendIO := pnet.NewPacketIO(cn, pnet.WithRemoteAddr(addr, cn.RemoteAddr()), pnet.WithWrapError(ErrBackendConn)) mgr.backendIO.Store(backendIO) mgr.setKeepAlive(mgr.config.HealthyKeepAlive) return backendIO, nil @@ -241,6 +246,7 @@ func (mgr *BackendConnManager) getBackendIO(cctx ConnContext, auth *Authenticato backoff.WithContext(backoff.NewConstantBackOff(200*time.Millisecond), bctx), func(err error, d time.Duration) { origErr = err + mgr.setQuitSourceByErr(err) mgr.handshakeHandler.OnHandshake(cctx, addr, err) }, ) @@ -264,9 +270,13 @@ func (mgr *BackendConnManager) getBackendIO(cctx ConnContext, auth *Authenticato // ExecuteCmd forwards messages between the client and the backend. // If it finds that the session is ready for redirection, it migrates the session. -func (mgr *BackendConnManager) ExecuteCmd(ctx context.Context, request []byte) error { +func (mgr *BackendConnManager) ExecuteCmd(ctx context.Context, request []byte) (err error) { + defer func() { + mgr.setQuitSourceByErr(err) + }() if len(request) < 1 { - return mysql.ErrMalformPacket + err = mysql.ErrMalformPacket + return } cmd := request[0] startTime := time.Now() @@ -275,17 +285,18 @@ func (mgr *BackendConnManager) ExecuteCmd(ctx context.Context, request []byte) e switch mgr.closeStatus.Load() { case statusClosing, statusClosed: - return nil + return } defer mgr.resetCheckBackendTicker() waitingRedirect := atomic.LoadPointer(&mgr.signal) != nil - holdRequest, err := mgr.cmdProcessor.executeCmd(request, mgr.clientIO, mgr.backendIO.Load(), waitingRedirect) + var holdRequest bool + holdRequest, err = mgr.cmdProcessor.executeCmd(request, mgr.clientIO, mgr.backendIO.Load(), waitingRedirect) if !holdRequest { addCmdMetrics(cmd, mgr.ServerAddr(), startTime) } if err != nil { if !IsMySQLError(err) { - return err + return } else { mgr.logger.Debug("got a mysql error", zap.Error(err)) } @@ -293,7 +304,7 @@ func (mgr *BackendConnManager) ExecuteCmd(ctx context.Context, request []byte) e if err == nil { switch cmd { case mysql.ComQuit: - return nil + return case mysql.ComSetOption: val := binary.LittleEndian.Uint16(request[1:]) switch val { @@ -304,12 +315,13 @@ func (mgr *BackendConnManager) ExecuteCmd(ctx context.Context, request []byte) e mgr.authenticator.capability &^= mysql.ClientMultiStatements mgr.cmdProcessor.capability &^= mysql.ClientMultiStatements default: - return errors.Errorf("unrecognized set_option value:%d", val) + err = errors.Errorf("unrecognized set_option value:%d", val) + return } case mysql.ComChangeUser: username, db := pnet.ParseChangeUser(request) mgr.authenticator.changeUser(username, db) - return nil + return } } // Even if it meets an MySQL error, it may have changed the status, such as when executing multi-statements. @@ -320,7 +332,7 @@ func (mgr *BackendConnManager) ExecuteCmd(ctx context.Context, request []byte) e _, err = mgr.cmdProcessor.executeCmd(request, mgr.clientIO, mgr.backendIO.Load(), false) addCmdMetrics(cmd, mgr.ServerAddr(), startTime) if err != nil && !IsMySQLError(err) { - return err + return } } else if mgr.closeStatus.Load() == statusNotifyClose { mgr.tryGracefulClose(ctx) @@ -329,7 +341,8 @@ func (mgr *BackendConnManager) ExecuteCmd(ctx context.Context, request []byte) e } } // Ignore MySQL errors, only return unexpected errors. - return nil + err = nil + return } // SetEventReceiver implements RedirectableConn.SetEventReceiver interface. @@ -428,6 +441,7 @@ func (mgr *BackendConnManager) tryRedirect(ctx context.Context) { // If the backend connection is closed, also close the client connection. // Otherwise, if the client is idle, the mgr will keep retrying. if errors.Is(rs.err, net.ErrClosed) || pnet.IsDisconnectError(rs.err) || errors.Is(rs.err, os.ErrDeadlineExceeded) { + mgr.quitSource = SrcBackendQuit if ignoredErr := mgr.clientIO.GracefulClose(); ignoredErr != nil { mgr.logger.Warn("graceful close client IO error", zap.Stringer("addr", mgr.clientIO.RemoteAddr()), zap.Error(ignoredErr)) } @@ -438,17 +452,20 @@ func (mgr *BackendConnManager) tryRedirect(ctx context.Context) { return } + defer mgr.resetQuitSource() var cn net.Conn cn, rs.err = net.DialTimeout("tcp", rs.to, DialTimeout) if rs.err != nil { + mgr.quitSource = SrcBackendQuit mgr.handshakeHandler.OnHandshake(mgr, rs.to, rs.err) return } - newBackendIO := pnet.NewPacketIO(cn, pnet.WithRemoteAddr(rs.to, cn.RemoteAddr())) + newBackendIO := pnet.NewPacketIO(cn, pnet.WithRemoteAddr(rs.to, cn.RemoteAddr()), pnet.WithWrapError(ErrBackendConn)) if rs.err = mgr.authenticator.handshakeSecondTime(mgr.logger, mgr.clientIO, newBackendIO, mgr.backendTLS, sessionToken); rs.err == nil { rs.err = mgr.initSessionStates(newBackendIO, sessionStates) } else { + mgr.setQuitSourceByErr(rs.err) mgr.handshakeHandler.OnHandshake(mgr, newBackendIO.RemoteAddr().String(), rs.err) } if rs.err != nil { @@ -538,6 +555,7 @@ func (mgr *BackendConnManager) tryGracefulClose(ctx context.Context) { if !mgr.cmdProcessor.finishedTxn() { return } + mgr.quitSource = SrcProxyQuit // Closing clientIO will cause the whole connection to be closed. if err := mgr.clientIO.GracefulClose(); err != nil { mgr.logger.Warn("graceful close client IO error", zap.Stringer("addr", mgr.clientIO.RemoteAddr()), zap.Error(err)) @@ -557,6 +575,7 @@ func (mgr *BackendConnManager) checkBackendActive() { if !backendIO.IsPeerActive() { mgr.logger.Info("backend connection is closed, close client connection", zap.Stringer("client", mgr.clientIO.RemoteAddr()), zap.Stringer("backend", backendIO.RemoteAddr())) + mgr.quitSource = SrcBackendQuit if err := mgr.clientIO.GracefulClose(); err != nil { mgr.logger.Warn("graceful close client IO error", zap.Stringer("addr", mgr.clientIO.RemoteAddr()), zap.Error(err)) } @@ -602,6 +621,10 @@ func (mgr *BackendConnManager) ClientOutBytes() uint64 { return mgr.clientIO.OutBytes() } +func (mgr *BackendConnManager) QuitSource() ErrorSource { + return mgr.quitSource +} + func (mgr *BackendConnManager) SetValue(key, val any) { mgr.ctxmap.Store(key, val) } @@ -675,3 +698,25 @@ func (mgr *BackendConnManager) setKeepAlive(cfg config.KeepAlive) { mgr.logger.Warn("failed to set keepalive", zap.Error(err), zap.Stringer("backend", backendIO.RemoteAddr())) } } + +// quitSource will be read by OnHandshake and OnConnClose, so setQuitSourceByErr should be called before them. +func (mgr *BackendConnManager) setQuitSourceByErr(err error) { + // Do not update the source if err is nil. It may be already be set. + if err == nil { + return + } + if errors.Is(err, ErrBackendConn) { + mgr.quitSource = SrcBackendQuit + } else if IsMySQLError(err) { + mgr.quitSource = SrcClientErr + } else if !errors.Is(err, ErrClientConn) { + mgr.quitSource = SrcProxyErr + } +} + +func (mgr *BackendConnManager) resetQuitSource() { + // SrcClientQuit is by default. + // Sometimes ErrClientConn is caused by GracefulClose and the quitSource is already set. + // Error maybe set during handshake for OnHandshake. If handshake finally succeeds, we reset it. + mgr.quitSource = SrcClientQuit +} diff --git a/pkg/proxy/backend/backend_conn_mgr_test.go b/pkg/proxy/backend/backend_conn_mgr_test.go index 54f26317..fc07ea9e 100644 --- a/pkg/proxy/backend/backend_conn_mgr_test.go +++ b/pkg/proxy/backend/backend_conn_mgr_test.go @@ -248,6 +248,7 @@ func TestNormalRedirect(t *testing.T) { ts.mp.Redirect(ts.tc.backendListener.Addr().String()) ts.mp.getEventReceiver().(*mockEventReceiver).checkEvent(t, eventSucceed) require.NotEqual(t, backend1, ts.mp.backendIO.Load()) + require.Equal(t, SrcClientQuit, ts.mp.QuitSource()) return nil }, backend: ts.redirectSucceed4Backend, @@ -352,6 +353,7 @@ func TestRedirectInTxn(t *testing.T) { require.NoError(t, err) ts.mp.getEventReceiver().(*mockEventReceiver).checkEvent(t, eventFail) require.Equal(t, backend1, ts.mp.backendIO.Load()) + require.Equal(t, SrcClientQuit, ts.mp.QuitSource()) return nil }, backend: func(packetIO *pnet.PacketIO) error { @@ -388,6 +390,12 @@ func TestConnectFail(t *testing.T) { return ts.mb.authenticate(ts.tc.backendIO) }, }, + { + proxy: func(clientIO, backendIO *pnet.PacketIO) error { + require.Equal(t, SrcClientErr, ts.mp.QuitSource()) + return nil + }, + }, } ts.runTests(runners) } @@ -499,6 +507,7 @@ func TestSpecialCmds(t *testing.T) { ts.mp.Redirect(ts.tc.backendListener.Addr().String()) ts.mp.getEventReceiver().(*mockEventReceiver).checkEvent(t, eventSucceed) require.NotEqual(t, backend1, ts.mp.backendIO.Load()) + require.Equal(t, SrcClientQuit, ts.mp.QuitSource()) return nil }, backend: func(packetIO *pnet.PacketIO) error { @@ -599,6 +608,12 @@ func TestCustomHandshake(t *testing.T) { }, backend: ts.redirectSucceed4Backend, }, + { + proxy: func(clientIO, backendIO *pnet.PacketIO) error { + require.Equal(t, SrcClientQuit, ts.mp.QuitSource()) + return nil + }, + }, } ts.runTests(runners) } @@ -623,6 +638,12 @@ func TestGracefulCloseWhenIdle(t *testing.T) { { proxy: ts.checkConnClosed4Proxy, }, + { + proxy: func(clientIO, backendIO *pnet.PacketIO) error { + require.Equal(t, SrcProxyQuit, ts.mp.QuitSource()) + return nil + }, + }, } ts.runTests(runners) } @@ -661,6 +682,12 @@ func TestGracefulCloseWhenActive(t *testing.T) { { proxy: ts.checkConnClosed4Proxy, }, + { + proxy: func(clientIO, backendIO *pnet.PacketIO) error { + require.Equal(t, SrcProxyQuit, ts.mp.QuitSource()) + return nil + }, + }, } ts.runTests(runners) } @@ -685,14 +712,21 @@ func TestGracefulCloseBeforeHandshake(t *testing.T) { { proxy: ts.checkConnClosed4Proxy, }, + { + proxy: func(clientIO, backendIO *pnet.PacketIO) error { + require.Equal(t, SrcProxyQuit, ts.mp.QuitSource()) + return nil + }, + }, } ts.runTests(runners) } func TestHandlerReturnError(t *testing.T) { tests := []struct { - cfg cfgOverrider - errMsg string + cfg cfgOverrider + errMsg string + quitSource ErrorSource }{ { cfg: func(config *testConfig) { @@ -700,7 +734,8 @@ func TestHandlerReturnError(t *testing.T) { return nil, errors.New("mocked error") } }, - errMsg: "mocked error", + errMsg: "mocked error", + quitSource: SrcProxyErr, }, { cfg: func(config *testConfig) { @@ -708,7 +743,8 @@ func TestHandlerReturnError(t *testing.T) { return errors.New("mocked error") } }, - errMsg: "mocked error", + errMsg: "mocked error", + quitSource: SrcProxyErr, }, { // TODO: make it fail faster. @@ -717,7 +753,8 @@ func TestHandlerReturnError(t *testing.T) { return router.NewStaticRouter(nil), nil } }, - errMsg: connectErrMsg, + errMsg: connectErrMsg, + quitSource: SrcProxyErr, }, } for _, test := range tests { @@ -732,6 +769,7 @@ func TestHandlerReturnError(t *testing.T) { proxy: func(clientIO, backendIO *pnet.PacketIO) error { err := ts.mp.Connect(context.Background(), clientIO, ts.mp.frontendTLSConfig, ts.mp.backendTLSConfig) require.Error(t, err) + require.Equal(t, test.quitSource, ts.mp.QuitSource()) return nil }, backend: nil, @@ -761,6 +799,9 @@ func TestGetBackendIO(t *testing.T) { if err != nil && len(s) > 0 { badAddrs[s] = struct{}{} } + if err != nil { + require.Equal(t, SrcProxyErr, connContext.QuitSource()) + } }, } mgr := NewBackendConnManager(logger.CreateLoggerForTest(t), handler, 0, &BCConfig{}) @@ -865,6 +906,12 @@ func TestBackendInactive(t *testing.T) { return packetIO.Close() }, }, + { + proxy: func(clientIO, backendIO *pnet.PacketIO) error { + require.Equal(t, SrcBackendQuit, ts.mp.QuitSource()) + return nil + }, + }, } ts.runTests(runners) } diff --git a/pkg/proxy/backend/error.go b/pkg/proxy/backend/error.go index 6d8afb3a..be160363 100644 --- a/pkg/proxy/backend/error.go +++ b/pkg/proxy/backend/error.go @@ -27,6 +27,11 @@ const ( capabilityErrMsg = "Verify TiDB capability failed, please upgrade TiDB" ) +var ( + ErrClientConn = errors.New("this is an error from client") + ErrBackendConn = errors.New("this is an error from backend") +) + // UserError is returned to the client. // err is used to log and userMsg is used to report to the user. type UserError struct { diff --git a/pkg/proxy/backend/handshake_handler.go b/pkg/proxy/backend/handshake_handler.go index 79908b67..05c064eb 100644 --- a/pkg/proxy/backend/handshake_handler.go +++ b/pkg/proxy/backend/handshake_handler.go @@ -28,6 +28,41 @@ const ( ConnContextKeyTLSState ConnContextKey = "tls-state" ) +type ErrorSource int + +const ( + // SrcClientQuit includes: client quit; bad client conn + SrcClientQuit ErrorSource = iota + // SrcClientErr includes: wrong password; mal format packet + SrcClientErr + // SrcProxyQuit includes: proxy graceful shutdown + SrcProxyQuit + // SrcProxyErr includes: cannot get backend list; capability negotiation + SrcProxyErr + // SrcBackendQuit includes: backend quit + SrcBackendQuit + // SrcBackendErr is reserved + SrcBackendErr +) + +func (es ErrorSource) String() string { + switch es { + case SrcClientQuit: + return "client quit" + case SrcClientErr: + return "client error" + case SrcProxyQuit: + return "proxy shutdown" + case SrcProxyErr: + return "proxy error" + case SrcBackendQuit: + return "backend quit" + case SrcBackendErr: + return "backend error" + } + return "unknown" +} + var _ HandshakeHandler = (*DefaultHandshakeHandler)(nil) type ConnContext interface { @@ -35,6 +70,7 @@ type ConnContext interface { ServerAddr() string ClientInBytes() uint64 ClientOutBytes() uint64 + QuitSource() ErrorSource SetValue(key, val any) Value(key any) any } diff --git a/pkg/proxy/client/client_conn.go b/pkg/proxy/client/client_conn.go index 8d420684..0454bb97 100644 --- a/pkg/proxy/client/client_conn.go +++ b/pkg/proxy/client/client_conn.go @@ -17,9 +17,7 @@ package client import ( "context" "crypto/tls" - "io" "net" - "os" "github.com/pingcap/TiProxy/lib/util/errors" "github.com/pingcap/TiProxy/pkg/proxy/backend" @@ -28,10 +26,6 @@ import ( "go.uber.org/zap" ) -var ( - ErrClientConn = errors.New("this is an error from client") -) - type ClientConnection struct { logger *zap.Logger frontendTLSConfig *tls.Config // the TLS config to connect to clients. @@ -44,7 +38,7 @@ func NewClientConnection(logger *zap.Logger, conn net.Conn, frontendTLSConfig *t hsHandler backend.HandshakeHandler, connID uint64, bcConfig *backend.BCConfig) *ClientConnection { bemgr := backend.NewBackendConnManager(logger.Named("be"), hsHandler, connID, bcConfig) opts := make([]pnet.PacketIOption, 0, 2) - opts = append(opts, pnet.WithWrapError(ErrClientConn)) + opts = append(opts, pnet.WithWrapError(backend.ErrClientConn)) if bcConfig.ProxyProtocol { opts = append(opts, pnet.WithProxy) } @@ -72,12 +66,12 @@ func (cc *ClientConnection) Run(ctx context.Context) { } clean: - clientErr := errors.Is(err, ErrClientConn) - // EOF: client closes; DeadlineExceeded: graceful shutdown; Closed: shut down. - if clientErr && (errors.Is(err, io.EOF) || errors.Is(err, os.ErrDeadlineExceeded) || errors.Is(err, net.ErrClosed)) { - return + src := cc.connMgr.QuitSource() + switch src { + case backend.SrcClientQuit, backend.SrcClientErr, backend.SrcProxyQuit: + default: + cc.logger.Info(msg, zap.Error(err), zap.Stringer("quit source", src)) } - cc.logger.Info(msg, zap.Error(err), zap.Bool("clientErr", clientErr), zap.Bool("serverErr", !clientErr)) } func (cc *ClientConnection) processMsg(ctx context.Context) error {