Skip to content

Commit

Permalink
backend, net: make connection concurrency safe (pingcap#232)
Browse files Browse the repository at this point in the history
  • Loading branch information
djshow832 authored and xhebox committed Mar 13, 2023
1 parent 6296a67 commit 9a3d280
Show file tree
Hide file tree
Showing 7 changed files with 74 additions and 68 deletions.
47 changes: 25 additions & 22 deletions pkg/proxy/backend/backend_conn_mgr.go
Original file line number Diff line number Diff line change
Expand Up @@ -120,9 +120,10 @@ type BackendConnManager struct {
closeStatus atomic.Int32
checkBackendTicker *time.Ticker
// cancelFunc is used to cancel the signal processing goroutine.
cancelFunc context.CancelFunc
clientIO *pnet.PacketIO
backendIO *pnet.PacketIO
cancelFunc context.CancelFunc
clientIO *pnet.PacketIO
// backendIO may be written during redirection and read in ExecuteCmd/Redirect/setKeepalive, so it is accessed atomically.
backendIO atomic.Pointer[pnet.PacketIO]
backendTLS *tls.Config
handshakeHandler HandshakeHandler
ctxmap sync.Map
Expand Down Expand Up @@ -228,8 +229,9 @@ func (mgr *BackendConnManager) getBackendIO(cctx ConnContext, auth *Authenticato
// NOTE: use the DNS name whenever possible.
// Certs are usually signed for domain names rather than IP addresses,
// whereas `RemoteAddr()` returns an IP address.
mgr.backendIO = pnet.NewPacketIO(cn, pnet.WithRemoteAddr(addr, cn.RemoteAddr()))
return mgr.backendIO, nil
backendIO := pnet.NewPacketIO(cn, pnet.WithRemoteAddr(addr, cn.RemoteAddr()))
mgr.backendIO.Store(backendIO)
return backendIO, nil
},
backoff.WithContext(backoff.NewConstantBackOff(200*time.Millisecond), bctx),
func(err error, d time.Duration) {
Expand All @@ -245,7 +247,7 @@ func (mgr *BackendConnManager) getBackendIO(cctx ConnContext, auth *Authenticato
mgr.logger.Error("get backend failed", zap.Duration("duration", duration), zap.NamedError("last_err", origErr))
} else if duration >= 3*time.Second {
mgr.logger.Warn("get backend slow", zap.Duration("duration", duration), zap.NamedError("last_err", origErr),
zap.Stringer("backend_addr", mgr.backendIO.RemoteAddr()))
zap.String("backend_addr", mgr.ServerAddr()))
}
if err != nil && errors.Is(err, context.DeadlineExceeded) {
if origErr != nil {
Expand All @@ -272,7 +274,7 @@ func (mgr *BackendConnManager) ExecuteCmd(ctx context.Context, request []byte) e
}
defer mgr.resetCheckBackendTicker()
waitingRedirect := atomic.LoadPointer(&mgr.signal) != nil
holdRequest, err := mgr.cmdProcessor.executeCmd(request, mgr.clientIO, mgr.backendIO, waitingRedirect)
holdRequest, err := mgr.cmdProcessor.executeCmd(request, mgr.clientIO, mgr.backendIO.Load(), waitingRedirect)
if !holdRequest {
addCmdMetrics(cmd, mgr.ServerAddr(), startTime)
}
Expand Down Expand Up @@ -310,7 +312,7 @@ func (mgr *BackendConnManager) ExecuteCmd(ctx context.Context, request []byte) e
if waitingRedirect && holdRequest {
mgr.tryRedirect(ctx)
// Execute the held request whether or not the redirection succeeds.
_, err = mgr.cmdProcessor.executeCmd(request, mgr.clientIO, mgr.backendIO, false)
_, err = mgr.cmdProcessor.executeCmd(request, mgr.clientIO, mgr.backendIO.Load(), false)
addCmdMetrics(cmd, mgr.ServerAddr(), startTime)
if err != nil && !IsMySQLError(err) {
return err
Expand Down Expand Up @@ -348,10 +350,10 @@ func (mgr *BackendConnManager) initSessionStates(backendIO *pnet.PacketIO, sessi
return err
}

func (mgr *BackendConnManager) querySessionStates() (sessionStates, sessionToken string, err error) {
func (mgr *BackendConnManager) querySessionStates(backendIO *pnet.PacketIO) (sessionStates, sessionToken string, err error) {
// Do not lock here because the caller already holds the lock.
var result *gomysql.Result
if result, _, err = mgr.cmdProcessor.query(mgr.backendIO, sqlQueryState); err != nil {
if result, _, err = mgr.cmdProcessor.query(backendIO, sqlQueryState); err != nil {
return
}
if sessionStates, err = result.GetStringByName(0, sessionStatesCol); err != nil {
Expand Down Expand Up @@ -415,8 +417,9 @@ func (mgr *BackendConnManager) tryRedirect(ctx context.Context) {
// - Avoid the risk of deadlock
mgr.redirectResCh <- rs
}()
backendIO := mgr.backendIO.Load()
var sessionStates, sessionToken string
if sessionStates, sessionToken, rs.err = mgr.querySessionStates(); rs.err != nil {
if sessionStates, sessionToken, rs.err = mgr.querySessionStates(backendIO); rs.err != nil {
return
}
if rs.err = mgr.updateAuthInfoFromSessionStates(hack.Slice(sessionStates)); rs.err != nil {
Expand All @@ -442,11 +445,11 @@ func (mgr *BackendConnManager) tryRedirect(ctx context.Context) {
}
return
}
if ignoredErr := mgr.backendIO.Close(); ignoredErr != nil && !pnet.IsDisconnectError(ignoredErr) {
if ignoredErr := backendIO.Close(); ignoredErr != nil && !pnet.IsDisconnectError(ignoredErr) {
mgr.logger.Error("close previous backend connection failed", zap.Error(ignoredErr))
}

mgr.backendIO = newBackendIO
mgr.backendIO.Store(newBackendIO)
mgr.handshakeHandler.OnHandshake(mgr, mgr.ServerAddr(), nil)
}

Expand Down Expand Up @@ -538,9 +541,10 @@ func (mgr *BackendConnManager) checkBackendActive() {

mgr.processLock.Lock()
defer mgr.processLock.Unlock()
if !mgr.backendIO.IsPeerActive() {
backendIO := mgr.backendIO.Load()
if !backendIO.IsPeerActive() {
mgr.logger.Info("backend connection is closed, close client connection", zap.Stringer("client", mgr.clientIO.RemoteAddr()),
zap.Stringer("backend", mgr.backendIO.RemoteAddr()))
zap.Stringer("backend", backendIO.RemoteAddr()))
if err := mgr.clientIO.GracefulClose(); err != nil {
mgr.logger.Warn("graceful close client IO error", zap.Stringer("addr", mgr.clientIO.RemoteAddr()), zap.Error(err))
}
Expand All @@ -566,10 +570,10 @@ func (mgr *BackendConnManager) ClientAddr() string {
}

func (mgr *BackendConnManager) ServerAddr() string {
if mgr.backendIO == nil {
return ""
if backendIO := mgr.backendIO.Load(); backendIO != nil {
return backendIO.RemoteAddr().String()
}
return mgr.backendIO.RemoteAddr().String()
return ""
}

func (mgr *BackendConnManager) ClientInBytes() uint64 {
Expand Down Expand Up @@ -613,10 +617,9 @@ func (mgr *BackendConnManager) Close() error {
var connErr error
var addr string
mgr.processLock.Lock()
if mgr.backendIO != nil {
addr = mgr.ServerAddr()
connErr = mgr.backendIO.Close()
mgr.backendIO = nil
if backendIO := mgr.backendIO.Swap(nil); backendIO != nil {
addr = backendIO.RemoteAddr().String()
connErr = backendIO.Close()
}
mgr.processLock.Unlock()

Expand Down
28 changes: 14 additions & 14 deletions pkg/proxy/backend/backend_conn_mgr_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -181,32 +181,32 @@ func (ts *backendMgrTester) startTxn4Backend(packetIO *pnet.PacketIO) error {
func (ts *backendMgrTester) checkNotRedirected4Proxy(clientIO, backendIO *pnet.PacketIO) error {
signal := (*signalRedirect)(atomic.LoadPointer(&ts.mp.signal))
require.Nil(ts.t, signal)
backend1 := ts.mp.backendIO
backend1 := ts.mp.backendIO.Load()
// There is no other way to verify it's not redirected.
// The buffer size of channel signalReceived is 0, so after the second redirect signal is sent,
// we can ensure that the first signal is already processed.
ts.mp.Redirect(ts.tc.backendListener.Addr().String())
ts.mp.signalReceived <- signalTypeRedirect
// The backend connection is still the same.
require.Equal(ts.t, backend1, ts.mp.backendIO)
require.Equal(ts.t, backend1, ts.mp.backendIO.Load())
return nil
}

func (ts *backendMgrTester) redirectAfterCmd4Proxy(clientIO, backendIO *pnet.PacketIO) error {
backend1 := ts.mp.backendIO
backend1 := ts.mp.backendIO.Load()
err := ts.forwardCmd4Proxy(clientIO, backendIO)
require.NoError(ts.t, err)
ts.mp.getEventReceiver().(*mockEventReceiver).checkEvent(ts.t, eventSucceed)
require.NotEqual(ts.t, backend1, ts.mp.backendIO)
require.NotEqual(ts.t, backend1, ts.mp.backendIO.Load())
require.Len(ts.t, ts.mp.GetRedirectingAddr(), 0)
return nil
}

func (ts *backendMgrTester) redirectFail4Proxy(clientIO, backendIO *pnet.PacketIO) error {
backend1 := ts.mp.backendIO
backend1 := ts.mp.backendIO.Load()
ts.mp.Redirect(ts.tc.backendListener.Addr().String())
ts.mp.getEventReceiver().(*mockEventReceiver).checkEvent(ts.t, eventFail)
require.Equal(ts.t, backend1, ts.mp.backendIO)
require.Equal(ts.t, backend1, ts.mp.backendIO.Load())
require.Len(ts.t, ts.mp.GetRedirectingAddr(), 0)
return nil
}
Expand Down Expand Up @@ -244,10 +244,10 @@ func TestNormalRedirect(t *testing.T) {
{
client: nil,
proxy: func(_, _ *pnet.PacketIO) error {
backend1 := ts.mp.backendIO
backend1 := ts.mp.backendIO.Load()
ts.mp.Redirect(ts.tc.backendListener.Addr().String())
ts.mp.getEventReceiver().(*mockEventReceiver).checkEvent(t, eventSucceed)
require.NotEqual(t, backend1, ts.mp.backendIO)
require.NotEqual(t, backend1, ts.mp.backendIO.Load())
return nil
},
backend: ts.redirectSucceed4Backend,
Expand Down Expand Up @@ -347,11 +347,11 @@ func TestRedirectInTxn(t *testing.T) {
return ts.mc.request(packetIO)
},
proxy: func(clientIO, backendIO *pnet.PacketIO) error {
backend1 := ts.mp.backendIO
backend1 := ts.mp.backendIO.Load()
err := ts.forwardCmd4Proxy(clientIO, backendIO)
require.NoError(t, err)
ts.mp.getEventReceiver().(*mockEventReceiver).checkEvent(t, eventFail)
require.Equal(t, backend1, ts.mp.backendIO)
require.Equal(t, backend1, ts.mp.backendIO.Load())
return nil
},
backend: func(packetIO *pnet.PacketIO) error {
Expand Down Expand Up @@ -495,10 +495,10 @@ func TestSpecialCmds(t *testing.T) {
{
client: nil,
proxy: func(_, _ *pnet.PacketIO) error {
backend1 := ts.mp.backendIO
backend1 := ts.mp.backendIO.Load()
ts.mp.Redirect(ts.tc.backendListener.Addr().String())
ts.mp.getEventReceiver().(*mockEventReceiver).checkEvent(t, eventSucceed)
require.NotEqual(t, backend1, ts.mp.backendIO)
require.NotEqual(t, backend1, ts.mp.backendIO.Load())
return nil
},
backend: func(packetIO *pnet.PacketIO) error {
Expand Down Expand Up @@ -591,10 +591,10 @@ func TestCustomHandshake(t *testing.T) {
{
client: nil,
proxy: func(_, _ *pnet.PacketIO) error {
backend1 := ts.mp.backendIO
backend1 := ts.mp.backendIO.Load()
ts.mp.Redirect(ts.tc.backendListener.Addr().String())
ts.mp.getEventReceiver().(*mockEventReceiver).checkEvent(t, eventSucceed)
require.NotEqual(t, backend1, ts.mp.backendIO)
require.NotEqual(t, backend1, ts.mp.backendIO.Load())
return nil
},
backend: ts.redirectSucceed4Backend,
Expand Down
40 changes: 21 additions & 19 deletions pkg/proxy/net/packetio.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,13 +41,13 @@ import (
"crypto/tls"
"io"
"net"
"sync/atomic"
"time"

"github.com/pingcap/TiProxy/lib/util/errors"
"github.com/pingcap/tidb/errno"
"github.com/pingcap/tidb/parser/mysql"
"github.com/pingcap/tidb/util/dbterror"
"go.uber.org/atomic"
)

var (
Expand Down Expand Up @@ -75,11 +75,12 @@ func (f *rdbufConn) Read(b []byte) (int, error) {

// PacketIO is a helper to read and write sql and proxy protocol.
type PacketIO struct {
inBytes uint64
outBytes uint64
conn net.Conn
inBytes uint64
outBytes uint64
// conn may be written during the TLS handshake while being read concurrently when setting keepalive.
conn atomic.Pointer[net.Conn]
buf *bufio.ReadWriter
proxyInited *atomic.Bool
proxyInited atomic.Bool
proxy *Proxy
remoteAddr net.Addr
wrap error
Expand All @@ -92,15 +93,16 @@ func NewPacketIO(conn net.Conn, opts ...PacketIOption) *PacketIO {
bufio.NewWriterSize(conn, defaultWriterSize),
)
p := &PacketIO{
conn: &rdbufConn{
conn,
buf.Reader,
},
sequence: 0,
// TODO: disable it by default now
proxyInited: atomic.NewBool(true),
buf: buf,
buf: buf,
}
// TODO: disable it by default now
p.proxyInited.Store(true)
cn := (net.Conn)(&rdbufConn{
conn,
buf.Reader,
})
p.conn.Store(&cn)
for _, opt := range opts {
opt(p)
}
Expand All @@ -120,14 +122,14 @@ func (p *PacketIO) Proxy() *Proxy {
}

func (p *PacketIO) LocalAddr() net.Addr {
return p.conn.LocalAddr()
return (*p.conn.Load()).LocalAddr()
}

func (p *PacketIO) RemoteAddr() net.Addr {
if p.remoteAddr != nil {
return p.remoteAddr
}
return p.conn.RemoteAddr()
return (*p.conn.Load()).RemoteAddr()
}

func (p *PacketIO) ResetSequence() {
Expand Down Expand Up @@ -256,7 +258,7 @@ func (p *PacketIO) OutBytes() uint64 {
}

func (p *PacketIO) TLSConnectionState() tls.ConnectionState {
if tlsConn, ok := p.conn.(*tls.Conn); ok {
if tlsConn, ok := (*p.conn.Load()).(*tls.Conn); ok {
return tlsConn.ConnectionState()
}
return tls.ConnectionState{}
Expand All @@ -274,21 +276,21 @@ func (p *PacketIO) Flush() error {
// This function normally costs 1ms, so don't call it too frequently.
// This function may incorrectly return true if the system is extremely slow.
func (p *PacketIO) IsPeerActive() bool {
if err := p.conn.SetReadDeadline(time.Now().Add(time.Millisecond)); err != nil {
if err := (*p.conn.Load()).SetReadDeadline(time.Now().Add(time.Millisecond)); err != nil {
return false
}
active := true
if _, err := p.buf.Peek(1); err != nil {
active = !errors.Is(err, io.EOF)
}
if err := p.conn.SetReadDeadline(time.Time{}); err != nil {
if err := (*p.conn.Load()).SetReadDeadline(time.Time{}); err != nil {
return false
}
return active
}

func (p *PacketIO) GracefulClose() error {
if err := p.conn.SetDeadline(time.Now()); err != nil && !errors.Is(err, net.ErrClosed) {
if err := (*p.conn.Load()).SetDeadline(time.Now()); err != nil && !errors.Is(err, net.ErrClosed) {
return err
}
return nil
Expand All @@ -302,7 +304,7 @@ func (p *PacketIO) Close() error {
errs = append(errs, err)
}
*/
if err := p.conn.Close(); err != nil && !errors.Is(err, net.ErrClosed) {
if err := (*p.conn.Load()).Close(); err != nil && !errors.Is(err, net.ErrClosed) {
errs = append(errs, err)
}
return p.wrapErr(errors.Collect(ErrCloseConn, errs...))
Expand Down
4 changes: 1 addition & 3 deletions pkg/proxy/net/packetio_options.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,14 +16,12 @@ package net

import (
"net"

"go.uber.org/atomic"
)

type PacketIOption = func(*PacketIO)

func WithProxy(pi *PacketIO) {
pi.proxyInited = atomic.NewBool(true)
pi.proxyInited.Store(true)
}

func WithWrapError(err error) func(pi *PacketIO) {
Expand Down
2 changes: 1 addition & 1 deletion pkg/proxy/net/packetio_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,7 @@ func TestPacketIO(t *testing.T) {
func TestTLS(t *testing.T) {
stls, ctls, err := security.CreateTLSConfigForTest()
require.NoError(t, err)
message := []byte("hello wolrd")
message := []byte("hello world")
testTCPConn(t,
func(t *testing.T, cli *PacketIO) {
data, err := cli.ReadPacket()
Expand Down
2 changes: 1 addition & 1 deletion pkg/proxy/net/proxy_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ func TestProxyParse(t *testing.T) {
func(t *testing.T, srv *PacketIO) {
// skip 4 bytes of magic
var hdr [4]byte
_, err := io.ReadFull(srv.conn, hdr[:])
_, err := io.ReadFull(*srv.conn.Load(), hdr[:])
require.NoError(t, err)

// try to parse V2
Expand Down
Loading

0 comments on commit 9a3d280

Please sign in to comment.