From 9a3d280cad765ef1e3e3bf127da22ae4a27e4f20 Mon Sep 17 00:00:00 2001 From: djshow832 Date: Sun, 5 Mar 2023 17:04:12 +0800 Subject: [PATCH] backend, net: make connection concurrency safe (#232) --- pkg/proxy/backend/backend_conn_mgr.go | 47 ++++++++++++---------- pkg/proxy/backend/backend_conn_mgr_test.go | 28 ++++++------- pkg/proxy/net/packetio.go | 40 +++++++++--------- pkg/proxy/net/packetio_options.go | 4 +- pkg/proxy/net/packetio_test.go | 2 +- pkg/proxy/net/proxy_test.go | 2 +- pkg/proxy/net/tls.go | 19 +++++---- 7 files changed, 74 insertions(+), 68 deletions(-) diff --git a/pkg/proxy/backend/backend_conn_mgr.go b/pkg/proxy/backend/backend_conn_mgr.go index a3018afc..0e3ad7ba 100644 --- a/pkg/proxy/backend/backend_conn_mgr.go +++ b/pkg/proxy/backend/backend_conn_mgr.go @@ -120,9 +120,10 @@ type BackendConnManager struct { closeStatus atomic.Int32 checkBackendTicker *time.Ticker // cancelFunc is used to cancel the signal processing goroutine. - cancelFunc context.CancelFunc - clientIO *pnet.PacketIO - backendIO *pnet.PacketIO + cancelFunc context.CancelFunc + clientIO *pnet.PacketIO + // backendIO may be written during redirection and be read in ExecuteCmd/Redirect/setKeepalive. + backendIO atomic.Pointer[pnet.PacketIO] backendTLS *tls.Config handshakeHandler HandshakeHandler ctxmap sync.Map @@ -228,8 +229,9 @@ func (mgr *BackendConnManager) getBackendIO(cctx ConnContext, auth *Authenticato // NOTE: should use DNS name as much as possible // Usually certs are signed with domain instead of IP addrs // And `RemoteAddr()` will return IP addr - mgr.backendIO = pnet.NewPacketIO(cn, pnet.WithRemoteAddr(addr, cn.RemoteAddr())) - return mgr.backendIO, nil + backendIO := pnet.NewPacketIO(cn, pnet.WithRemoteAddr(addr, cn.RemoteAddr())) + mgr.backendIO.Store(backendIO) + return backendIO, nil }, backoff.WithContext(backoff.NewConstantBackOff(200*time.Millisecond), bctx), func(err error, d time.Duration) { @@ -245,7 +247,7 @@ func (mgr *BackendConnManager) getBackendIO(cctx ConnContext, auth *Authenticato mgr.logger.Error("get backend failed", zap.Duration("duration", duration), zap.NamedError("last_err", origErr)) } else if duration >= 3*time.Second { mgr.logger.Warn("get backend slow", zap.Duration("duration", duration), zap.NamedError("last_err", origErr), - zap.Stringer("backend_addr", mgr.backendIO.RemoteAddr())) + zap.String("backend_addr", mgr.ServerAddr())) } if err != nil && errors.Is(err, context.DeadlineExceeded) { if origErr != nil { @@ -272,7 +274,7 @@ func (mgr *BackendConnManager) ExecuteCmd(ctx context.Context, request []byte) e } defer mgr.resetCheckBackendTicker() waitingRedirect := atomic.LoadPointer(&mgr.signal) != nil - holdRequest, err := mgr.cmdProcessor.executeCmd(request, mgr.clientIO, mgr.backendIO, waitingRedirect) + holdRequest, err := mgr.cmdProcessor.executeCmd(request, mgr.clientIO, mgr.backendIO.Load(), waitingRedirect) if !holdRequest { addCmdMetrics(cmd, mgr.ServerAddr(), startTime) } @@ -310,7 +312,7 @@ func (mgr *BackendConnManager) ExecuteCmd(ctx context.Context, request []byte) e if waitingRedirect && holdRequest { mgr.tryRedirect(ctx) // Execute the held request no matter redirection succeeds or not. - _, err = mgr.cmdProcessor.executeCmd(request, mgr.clientIO, mgr.backendIO, false) + _, err = mgr.cmdProcessor.executeCmd(request, mgr.clientIO, mgr.backendIO.Load(), false) addCmdMetrics(cmd, mgr.ServerAddr(), startTime) if err != nil && !IsMySQLError(err) { return err @@ -348,10 +350,10 @@ func (mgr *BackendConnManager) initSessionStates(backendIO *pnet.PacketIO, sessi return err } -func (mgr *BackendConnManager) querySessionStates() (sessionStates, sessionToken string, err error) { +func (mgr *BackendConnManager) querySessionStates(backendIO *pnet.PacketIO) (sessionStates, sessionToken string, err error) { // Do not lock here because the caller already locks. var result *gomysql.Result - if result, _, err = mgr.cmdProcessor.query(mgr.backendIO, sqlQueryState); err != nil { + if result, _, err = mgr.cmdProcessor.query(backendIO, sqlQueryState); err != nil { return } if sessionStates, err = result.GetStringByName(0, sessionStatesCol); err != nil { @@ -415,8 +417,9 @@ func (mgr *BackendConnManager) tryRedirect(ctx context.Context) { // - Avoid the risk of deadlock mgr.redirectResCh <- rs }() + backendIO := mgr.backendIO.Load() var sessionStates, sessionToken string - if sessionStates, sessionToken, rs.err = mgr.querySessionStates(); rs.err != nil { + if sessionStates, sessionToken, rs.err = mgr.querySessionStates(backendIO); rs.err != nil { return } if rs.err = mgr.updateAuthInfoFromSessionStates(hack.Slice(sessionStates)); rs.err != nil { @@ -442,11 +445,11 @@ func (mgr *BackendConnManager) tryRedirect(ctx context.Context) { } return } - if ignoredErr := mgr.backendIO.Close(); ignoredErr != nil && !pnet.IsDisconnectError(ignoredErr) { + if ignoredErr := backendIO.Close(); ignoredErr != nil && !pnet.IsDisconnectError(ignoredErr) { mgr.logger.Error("close previous backend connection failed", zap.Error(ignoredErr)) } - mgr.backendIO = newBackendIO + mgr.backendIO.Store(newBackendIO) mgr.handshakeHandler.OnHandshake(mgr, mgr.ServerAddr(), nil) } @@ -538,9 +541,10 @@ func (mgr *BackendConnManager) checkBackendActive() { mgr.processLock.Lock() defer mgr.processLock.Unlock() - if !mgr.backendIO.IsPeerActive() { + backendIO := mgr.backendIO.Load() + if !backendIO.IsPeerActive() { mgr.logger.Info("backend connection is closed, close client connection", zap.Stringer("client", mgr.clientIO.RemoteAddr()), - zap.Stringer("backend", mgr.backendIO.RemoteAddr())) + zap.Stringer("backend", backendIO.RemoteAddr())) if err := mgr.clientIO.GracefulClose(); err != nil { mgr.logger.Warn("graceful close client IO error", zap.Stringer("addr", mgr.clientIO.RemoteAddr()), zap.Error(err)) } @@ -566,10 +570,10 @@ func (mgr *BackendConnManager) ClientAddr() string { } func (mgr *BackendConnManager) ServerAddr() string { - if mgr.backendIO == nil { - return "" + if backendIO := mgr.backendIO.Load(); backendIO != nil { + return backendIO.RemoteAddr().String() } - return mgr.backendIO.RemoteAddr().String() + return "" } func (mgr *BackendConnManager) ClientInBytes() uint64 { @@ -613,10 +617,9 @@ func (mgr *BackendConnManager) Close() error { var connErr error var addr string mgr.processLock.Lock() - if mgr.backendIO != nil { - addr = mgr.ServerAddr() - connErr = mgr.backendIO.Close() - mgr.backendIO = nil + if backendIO := mgr.backendIO.Swap(nil); backendIO != nil { + addr = backendIO.RemoteAddr().String() + connErr = backendIO.Close() } mgr.processLock.Unlock() diff --git a/pkg/proxy/backend/backend_conn_mgr_test.go b/pkg/proxy/backend/backend_conn_mgr_test.go index 89a267fe..bcf6d6bf 100644 --- a/pkg/proxy/backend/backend_conn_mgr_test.go +++ b/pkg/proxy/backend/backend_conn_mgr_test.go @@ -181,32 +181,32 @@ func (ts *backendMgrTester) startTxn4Backend(packetIO *pnet.PacketIO) error { func (ts *backendMgrTester) checkNotRedirected4Proxy(clientIO, backendIO *pnet.PacketIO) error { signal := (*signalRedirect)(atomic.LoadPointer(&ts.mp.signal)) require.Nil(ts.t, signal) - backend1 := ts.mp.backendIO + backend1 := ts.mp.backendIO.Load() // There is no other way to verify it's not redirected. // The buffer size of channel signalReceived is 0, so after the second redirect signal is sent, // we can ensure that the first signal is already processed. ts.mp.Redirect(ts.tc.backendListener.Addr().String()) ts.mp.signalReceived <- signalTypeRedirect // The backend connection is still the same. - require.Equal(ts.t, backend1, ts.mp.backendIO) + require.Equal(ts.t, backend1, ts.mp.backendIO.Load()) return nil } func (ts *backendMgrTester) redirectAfterCmd4Proxy(clientIO, backendIO *pnet.PacketIO) error { - backend1 := ts.mp.backendIO + backend1 := ts.mp.backendIO.Load() err := ts.forwardCmd4Proxy(clientIO, backendIO) require.NoError(ts.t, err) ts.mp.getEventReceiver().(*mockEventReceiver).checkEvent(ts.t, eventSucceed) - require.NotEqual(ts.t, backend1, ts.mp.backendIO) + require.NotEqual(ts.t, backend1, ts.mp.backendIO.Load()) require.Len(ts.t, ts.mp.GetRedirectingAddr(), 0) return nil } func (ts *backendMgrTester) redirectFail4Proxy(clientIO, backendIO *pnet.PacketIO) error { - backend1 := ts.mp.backendIO + backend1 := ts.mp.backendIO.Load() ts.mp.Redirect(ts.tc.backendListener.Addr().String()) ts.mp.getEventReceiver().(*mockEventReceiver).checkEvent(ts.t, eventFail) - require.Equal(ts.t, backend1, ts.mp.backendIO) + require.Equal(ts.t, backend1, ts.mp.backendIO.Load()) require.Len(ts.t, ts.mp.GetRedirectingAddr(), 0) return nil } @@ -244,10 +244,10 @@ func TestNormalRedirect(t *testing.T) { { client: nil, proxy: func(_, _ *pnet.PacketIO) error { - backend1 := ts.mp.backendIO + backend1 := ts.mp.backendIO.Load() ts.mp.Redirect(ts.tc.backendListener.Addr().String()) ts.mp.getEventReceiver().(*mockEventReceiver).checkEvent(t, eventSucceed) - require.NotEqual(t, backend1, ts.mp.backendIO) + require.NotEqual(t, backend1, ts.mp.backendIO.Load()) return nil }, backend: ts.redirectSucceed4Backend, @@ -347,11 +347,11 @@ func TestRedirectInTxn(t *testing.T) { return ts.mc.request(packetIO) }, proxy: func(clientIO, backendIO *pnet.PacketIO) error { - backend1 := ts.mp.backendIO + backend1 := ts.mp.backendIO.Load() err := ts.forwardCmd4Proxy(clientIO, backendIO) require.NoError(t, err) ts.mp.getEventReceiver().(*mockEventReceiver).checkEvent(t, eventFail) - require.Equal(t, backend1, ts.mp.backendIO) + require.Equal(t, backend1, ts.mp.backendIO.Load()) return nil }, backend: func(packetIO *pnet.PacketIO) error { @@ -495,10 +495,10 @@ func TestSpecialCmds(t *testing.T) { { client: nil, proxy: func(_, _ *pnet.PacketIO) error { - backend1 := ts.mp.backendIO + backend1 := ts.mp.backendIO.Load() ts.mp.Redirect(ts.tc.backendListener.Addr().String()) ts.mp.getEventReceiver().(*mockEventReceiver).checkEvent(t, eventSucceed) - require.NotEqual(t, backend1, ts.mp.backendIO) + require.NotEqual(t, backend1, ts.mp.backendIO.Load()) return nil }, backend: func(packetIO *pnet.PacketIO) error { @@ -591,10 +591,10 @@ func TestCustomHandshake(t *testing.T) { { client: nil, proxy: func(_, _ *pnet.PacketIO) error { - backend1 := ts.mp.backendIO + backend1 := ts.mp.backendIO.Load() ts.mp.Redirect(ts.tc.backendListener.Addr().String()) ts.mp.getEventReceiver().(*mockEventReceiver).checkEvent(t, eventSucceed) - require.NotEqual(t, backend1, ts.mp.backendIO) + require.NotEqual(t, backend1, ts.mp.backendIO.Load()) return nil }, backend: ts.redirectSucceed4Backend, diff --git a/pkg/proxy/net/packetio.go b/pkg/proxy/net/packetio.go index 054da3b5..f6d50d2a 100644 --- a/pkg/proxy/net/packetio.go +++ b/pkg/proxy/net/packetio.go @@ -41,13 +41,13 @@ import ( "crypto/tls" "io" "net" + "sync/atomic" "time" "github.com/pingcap/TiProxy/lib/util/errors" "github.com/pingcap/tidb/errno" "github.com/pingcap/tidb/parser/mysql" "github.com/pingcap/tidb/util/dbterror" - "go.uber.org/atomic" ) var ( @@ -75,11 +75,12 @@ func (f *rdbufConn) Read(b []byte) (int, error) { // PacketIO is a helper to read and write sql and proxy protocol. type PacketIO struct { - inBytes uint64 - outBytes uint64 - conn net.Conn + inBytes uint64 + outBytes uint64 + // conn is written during TLS handshake and read during setting keep alive concurrently. + conn atomic.Pointer[net.Conn] buf *bufio.ReadWriter - proxyInited *atomic.Bool + proxyInited atomic.Bool proxy *Proxy remoteAddr net.Addr wrap error @@ -92,15 +93,16 @@ func NewPacketIO(conn net.Conn, opts ...PacketIOption) *PacketIO { bufio.NewWriterSize(conn, defaultWriterSize), ) p := &PacketIO{ - conn: &rdbufConn{ - conn, - buf.Reader, - }, sequence: 0, - // TODO: disable it by default now - proxyInited: atomic.NewBool(true), - buf: buf, + buf: buf, } + // TODO: disable it by default now + p.proxyInited.Store(true) + cn := (net.Conn)(&rdbufConn{ + conn, + buf.Reader, + }) + p.conn.Store(&cn) for _, opt := range opts { opt(p) } @@ -120,14 +122,14 @@ func (p *PacketIO) Proxy() *Proxy { } func (p *PacketIO) LocalAddr() net.Addr { - return p.conn.LocalAddr() + return (*p.conn.Load()).LocalAddr() } func (p *PacketIO) RemoteAddr() net.Addr { if p.remoteAddr != nil { return p.remoteAddr } - return p.conn.RemoteAddr() + return (*p.conn.Load()).RemoteAddr() } func (p *PacketIO) ResetSequence() { @@ -256,7 +258,7 @@ func (p *PacketIO) OutBytes() uint64 { } func (p *PacketIO) TLSConnectionState() tls.ConnectionState { - if tlsConn, ok := p.conn.(*tls.Conn); ok { + if tlsConn, ok := (*p.conn.Load()).(*tls.Conn); ok { return tlsConn.ConnectionState() } return tls.ConnectionState{} @@ -274,21 +276,21 @@ func (p *PacketIO) Flush() error { // This function normally costs 1ms, so don't call it too frequently. // This function may incorrectly return true if the system is extremely slow. func (p *PacketIO) IsPeerActive() bool { - if err := p.conn.SetReadDeadline(time.Now().Add(time.Millisecond)); err != nil { + if err := (*p.conn.Load()).SetReadDeadline(time.Now().Add(time.Millisecond)); err != nil { return false } active := true if _, err := p.buf.Peek(1); err != nil { active = !errors.Is(err, io.EOF) } - if err := p.conn.SetReadDeadline(time.Time{}); err != nil { + if err := (*p.conn.Load()).SetReadDeadline(time.Time{}); err != nil { return false } return active } func (p *PacketIO) GracefulClose() error { - if err := p.conn.SetDeadline(time.Now()); err != nil && !errors.Is(err, net.ErrClosed) { + if err := (*p.conn.Load()).SetDeadline(time.Now()); err != nil && !errors.Is(err, net.ErrClosed) { return err } return nil @@ -302,7 +304,7 @@ func (p *PacketIO) Close() error { errs = append(errs, err) } */ - if err := p.conn.Close(); err != nil && !errors.Is(err, net.ErrClosed) { + if err := (*p.conn.Load()).Close(); err != nil && !errors.Is(err, net.ErrClosed) { errs = append(errs, err) } return p.wrapErr(errors.Collect(ErrCloseConn, errs...)) diff --git a/pkg/proxy/net/packetio_options.go b/pkg/proxy/net/packetio_options.go index eb570199..55129692 100644 --- a/pkg/proxy/net/packetio_options.go +++ b/pkg/proxy/net/packetio_options.go @@ -16,14 +16,12 @@ package net import ( "net" - - "go.uber.org/atomic" ) type PacketIOption = func(*PacketIO) func WithProxy(pi *PacketIO) { - pi.proxyInited = atomic.NewBool(true) + pi.proxyInited.Store(true) } func WithWrapError(err error) func(pi *PacketIO) { diff --git a/pkg/proxy/net/packetio_test.go b/pkg/proxy/net/packetio_test.go index 92456eba..c7583cf3 100644 --- a/pkg/proxy/net/packetio_test.go +++ b/pkg/proxy/net/packetio_test.go @@ -148,7 +148,7 @@ func TestPacketIO(t *testing.T) { func TestTLS(t *testing.T) { stls, ctls, err := security.CreateTLSConfigForTest() require.NoError(t, err) - message := []byte("hello wolrd") + message := []byte("hello world") testTCPConn(t, func(t *testing.T, cli *PacketIO) { data, err := cli.ReadPacket() diff --git a/pkg/proxy/net/proxy_test.go b/pkg/proxy/net/proxy_test.go index ab38708d..2126fc55 100644 --- a/pkg/proxy/net/proxy_test.go +++ b/pkg/proxy/net/proxy_test.go @@ -48,7 +48,7 @@ func TestProxyParse(t *testing.T) { func(t *testing.T, srv *PacketIO) { // skip 4 bytes of magic var hdr [4]byte - _, err := io.ReadFull(srv.conn, hdr[:]) + _, err := io.ReadFull(*srv.conn.Load(), hdr[:]) require.NoError(t, err) // try to parse V2 diff --git a/pkg/proxy/net/tls.go b/pkg/proxy/net/tls.go index 0e5ecb01..2d3e7c98 100644 --- a/pkg/proxy/net/tls.go +++ b/pkg/proxy/net/tls.go @@ -17,32 +17,35 @@ package net import ( "bufio" "crypto/tls" + "net" "github.com/pingcap/TiProxy/lib/util/errors" ) func (p *PacketIO) ServerTLSHandshake(tlsConfig *tls.Config) (tls.ConnectionState, error) { tlsConfig = tlsConfig.Clone() - tlsConn := tls.Server(p.conn, tlsConfig) + tlsConn := tls.Server(*p.conn.Load(), tlsConfig) if err := tlsConn.Handshake(); err != nil { return tls.ConnectionState{}, p.wrapErr(errors.Wrap(ErrHandshakeTLS, err)) } - p.conn = tlsConn - p.buf.Writer.Reset(p.conn) + conn := (net.Conn)(tlsConn) + p.conn.Store(&conn) + p.buf.Writer.Reset(conn) // Wrap it with another buffer to enable Peek. - p.buf = bufio.NewReadWriter(bufio.NewReaderSize(p.conn, defaultReaderSize), p.buf.Writer) + p.buf = bufio.NewReadWriter(bufio.NewReaderSize(conn, defaultReaderSize), p.buf.Writer) return tlsConn.ConnectionState(), nil } func (p *PacketIO) ClientTLSHandshake(tlsConfig *tls.Config) error { tlsConfig = tlsConfig.Clone() - tlsConn := tls.Client(p.conn, tlsConfig) + tlsConn := tls.Client(*p.conn.Load(), tlsConfig) if err := tlsConn.Handshake(); err != nil { return errors.WithStack(errors.Wrap(ErrHandshakeTLS, err)) } - p.conn = tlsConn - p.buf.Writer.Reset(p.conn) + conn := (net.Conn)(tlsConn) + p.conn.Store(&conn) + p.buf.Writer.Reset(conn) // Wrap it with another buffer to enable Peek. - p.buf = bufio.NewReadWriter(bufio.NewReaderSize(p.conn, defaultReaderSize), p.buf.Writer) + p.buf = bufio.NewReadWriter(bufio.NewReaderSize(conn, defaultReaderSize), p.buf.Writer) return nil }