Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

backend, net: make connection concurrency safe #232

Merged
merged 5 commits into from
Mar 5, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 25 additions & 22 deletions pkg/proxy/backend/backend_conn_mgr.go
Original file line number Diff line number Diff line change
Expand Up @@ -120,9 +120,10 @@ type BackendConnManager struct {
closeStatus atomic.Int32
checkBackendTicker *time.Ticker
// cancelFunc is used to cancel the signal processing goroutine.
cancelFunc context.CancelFunc
clientIO *pnet.PacketIO
backendIO *pnet.PacketIO
cancelFunc context.CancelFunc
clientIO *pnet.PacketIO
// backendIO may be written during redirection and be read in ExecuteCmd/Redirect/setKeepalive.
backendIO atomic.Pointer[pnet.PacketIO]
backendTLS *tls.Config
handshakeHandler HandshakeHandler
ctxmap sync.Map
Expand Down Expand Up @@ -228,8 +229,9 @@ func (mgr *BackendConnManager) getBackendIO(cctx ConnContext, auth *Authenticato
// NOTE: should use DNS name as much as possible
// Usually certs are signed with domain instead of IP addrs
// And `RemoteAddr()` will return IP addr
mgr.backendIO = pnet.NewPacketIO(cn, pnet.WithRemoteAddr(addr, cn.RemoteAddr()))
return mgr.backendIO, nil
backendIO := pnet.NewPacketIO(cn, pnet.WithRemoteAddr(addr, cn.RemoteAddr()))
mgr.backendIO.Store(backendIO)
return backendIO, nil
},
backoff.WithContext(backoff.NewConstantBackOff(200*time.Millisecond), bctx),
func(err error, d time.Duration) {
Expand All @@ -245,7 +247,7 @@ func (mgr *BackendConnManager) getBackendIO(cctx ConnContext, auth *Authenticato
mgr.logger.Error("get backend failed", zap.Duration("duration", duration), zap.NamedError("last_err", origErr))
} else if duration >= 3*time.Second {
mgr.logger.Warn("get backend slow", zap.Duration("duration", duration), zap.NamedError("last_err", origErr),
zap.Stringer("backend_addr", mgr.backendIO.RemoteAddr()))
zap.String("backend_addr", mgr.ServerAddr()))
}
if err != nil && errors.Is(err, context.DeadlineExceeded) {
if origErr != nil {
Expand All @@ -272,7 +274,7 @@ func (mgr *BackendConnManager) ExecuteCmd(ctx context.Context, request []byte) e
}
defer mgr.resetCheckBackendTicker()
waitingRedirect := atomic.LoadPointer(&mgr.signal) != nil
holdRequest, err := mgr.cmdProcessor.executeCmd(request, mgr.clientIO, mgr.backendIO, waitingRedirect)
holdRequest, err := mgr.cmdProcessor.executeCmd(request, mgr.clientIO, mgr.backendIO.Load(), waitingRedirect)
if !holdRequest {
addCmdMetrics(cmd, mgr.ServerAddr(), startTime)
}
Expand Down Expand Up @@ -310,7 +312,7 @@ func (mgr *BackendConnManager) ExecuteCmd(ctx context.Context, request []byte) e
if waitingRedirect && holdRequest {
mgr.tryRedirect(ctx)
// Execute the held request no matter redirection succeeds or not.
_, err = mgr.cmdProcessor.executeCmd(request, mgr.clientIO, mgr.backendIO, false)
_, err = mgr.cmdProcessor.executeCmd(request, mgr.clientIO, mgr.backendIO.Load(), false)
addCmdMetrics(cmd, mgr.ServerAddr(), startTime)
if err != nil && !IsMySQLError(err) {
return err
Expand Down Expand Up @@ -348,10 +350,10 @@ func (mgr *BackendConnManager) initSessionStates(backendIO *pnet.PacketIO, sessi
return err
}

func (mgr *BackendConnManager) querySessionStates() (sessionStates, sessionToken string, err error) {
func (mgr *BackendConnManager) querySessionStates(backendIO *pnet.PacketIO) (sessionStates, sessionToken string, err error) {
// Do not lock here because the caller already locks.
var result *gomysql.Result
if result, _, err = mgr.cmdProcessor.query(mgr.backendIO, sqlQueryState); err != nil {
if result, _, err = mgr.cmdProcessor.query(backendIO, sqlQueryState); err != nil {
return
}
if sessionStates, err = result.GetStringByName(0, sessionStatesCol); err != nil {
Expand Down Expand Up @@ -415,8 +417,9 @@ func (mgr *BackendConnManager) tryRedirect(ctx context.Context) {
// - Avoid the risk of deadlock
mgr.redirectResCh <- rs
}()
backendIO := mgr.backendIO.Load()
var sessionStates, sessionToken string
if sessionStates, sessionToken, rs.err = mgr.querySessionStates(); rs.err != nil {
if sessionStates, sessionToken, rs.err = mgr.querySessionStates(backendIO); rs.err != nil {
return
}
if rs.err = mgr.updateAuthInfoFromSessionStates(hack.Slice(sessionStates)); rs.err != nil {
Expand All @@ -442,11 +445,11 @@ func (mgr *BackendConnManager) tryRedirect(ctx context.Context) {
}
return
}
if ignoredErr := mgr.backendIO.Close(); ignoredErr != nil && !pnet.IsDisconnectError(ignoredErr) {
if ignoredErr := backendIO.Close(); ignoredErr != nil && !pnet.IsDisconnectError(ignoredErr) {
mgr.logger.Error("close previous backend connection failed", zap.Error(ignoredErr))
}

mgr.backendIO = newBackendIO
mgr.backendIO.Store(newBackendIO)
mgr.handshakeHandler.OnHandshake(mgr, mgr.ServerAddr(), nil)
}

Expand Down Expand Up @@ -538,9 +541,10 @@ func (mgr *BackendConnManager) checkBackendActive() {

mgr.processLock.Lock()
defer mgr.processLock.Unlock()
if !mgr.backendIO.IsPeerActive() {
backendIO := mgr.backendIO.Load()
if !backendIO.IsPeerActive() {
mgr.logger.Info("backend connection is closed, close client connection", zap.Stringer("client", mgr.clientIO.RemoteAddr()),
zap.Stringer("backend", mgr.backendIO.RemoteAddr()))
zap.Stringer("backend", backendIO.RemoteAddr()))
if err := mgr.clientIO.GracefulClose(); err != nil {
mgr.logger.Warn("graceful close client IO error", zap.Stringer("addr", mgr.clientIO.RemoteAddr()), zap.Error(err))
}
Expand All @@ -566,10 +570,10 @@ func (mgr *BackendConnManager) ClientAddr() string {
}

func (mgr *BackendConnManager) ServerAddr() string {
if mgr.backendIO == nil {
return ""
if backendIO := mgr.backendIO.Load(); backendIO != nil {
return backendIO.RemoteAddr().String()
}
return mgr.backendIO.RemoteAddr().String()
return ""
}

func (mgr *BackendConnManager) ClientInBytes() uint64 {
Expand Down Expand Up @@ -613,10 +617,9 @@ func (mgr *BackendConnManager) Close() error {
var connErr error
var addr string
mgr.processLock.Lock()
if mgr.backendIO != nil {
addr = mgr.ServerAddr()
connErr = mgr.backendIO.Close()
mgr.backendIO = nil
if backendIO := mgr.backendIO.Swap(nil); backendIO != nil {
addr = backendIO.RemoteAddr().String()
connErr = backendIO.Close()
}
mgr.processLock.Unlock()

Expand Down
28 changes: 14 additions & 14 deletions pkg/proxy/backend/backend_conn_mgr_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -181,32 +181,32 @@ func (ts *backendMgrTester) startTxn4Backend(packetIO *pnet.PacketIO) error {
func (ts *backendMgrTester) checkNotRedirected4Proxy(clientIO, backendIO *pnet.PacketIO) error {
signal := (*signalRedirect)(atomic.LoadPointer(&ts.mp.signal))
require.Nil(ts.t, signal)
backend1 := ts.mp.backendIO
backend1 := ts.mp.backendIO.Load()
// There is no other way to verify it's not redirected.
// The buffer size of channel signalReceived is 0, so after the second redirect signal is sent,
// we can ensure that the first signal is already processed.
ts.mp.Redirect(ts.tc.backendListener.Addr().String())
ts.mp.signalReceived <- signalTypeRedirect
// The backend connection is still the same.
require.Equal(ts.t, backend1, ts.mp.backendIO)
require.Equal(ts.t, backend1, ts.mp.backendIO.Load())
return nil
}

func (ts *backendMgrTester) redirectAfterCmd4Proxy(clientIO, backendIO *pnet.PacketIO) error {
backend1 := ts.mp.backendIO
backend1 := ts.mp.backendIO.Load()
err := ts.forwardCmd4Proxy(clientIO, backendIO)
require.NoError(ts.t, err)
ts.mp.getEventReceiver().(*mockEventReceiver).checkEvent(ts.t, eventSucceed)
require.NotEqual(ts.t, backend1, ts.mp.backendIO)
require.NotEqual(ts.t, backend1, ts.mp.backendIO.Load())
require.Len(ts.t, ts.mp.GetRedirectingAddr(), 0)
return nil
}

func (ts *backendMgrTester) redirectFail4Proxy(clientIO, backendIO *pnet.PacketIO) error {
backend1 := ts.mp.backendIO
backend1 := ts.mp.backendIO.Load()
ts.mp.Redirect(ts.tc.backendListener.Addr().String())
ts.mp.getEventReceiver().(*mockEventReceiver).checkEvent(ts.t, eventFail)
require.Equal(ts.t, backend1, ts.mp.backendIO)
require.Equal(ts.t, backend1, ts.mp.backendIO.Load())
require.Len(ts.t, ts.mp.GetRedirectingAddr(), 0)
return nil
}
Expand Down Expand Up @@ -244,10 +244,10 @@ func TestNormalRedirect(t *testing.T) {
{
client: nil,
proxy: func(_, _ *pnet.PacketIO) error {
backend1 := ts.mp.backendIO
backend1 := ts.mp.backendIO.Load()
ts.mp.Redirect(ts.tc.backendListener.Addr().String())
ts.mp.getEventReceiver().(*mockEventReceiver).checkEvent(t, eventSucceed)
require.NotEqual(t, backend1, ts.mp.backendIO)
require.NotEqual(t, backend1, ts.mp.backendIO.Load())
return nil
},
backend: ts.redirectSucceed4Backend,
Expand Down Expand Up @@ -347,11 +347,11 @@ func TestRedirectInTxn(t *testing.T) {
return ts.mc.request(packetIO)
},
proxy: func(clientIO, backendIO *pnet.PacketIO) error {
backend1 := ts.mp.backendIO
backend1 := ts.mp.backendIO.Load()
err := ts.forwardCmd4Proxy(clientIO, backendIO)
require.NoError(t, err)
ts.mp.getEventReceiver().(*mockEventReceiver).checkEvent(t, eventFail)
require.Equal(t, backend1, ts.mp.backendIO)
require.Equal(t, backend1, ts.mp.backendIO.Load())
return nil
},
backend: func(packetIO *pnet.PacketIO) error {
Expand Down Expand Up @@ -495,10 +495,10 @@ func TestSpecialCmds(t *testing.T) {
{
client: nil,
proxy: func(_, _ *pnet.PacketIO) error {
backend1 := ts.mp.backendIO
backend1 := ts.mp.backendIO.Load()
ts.mp.Redirect(ts.tc.backendListener.Addr().String())
ts.mp.getEventReceiver().(*mockEventReceiver).checkEvent(t, eventSucceed)
require.NotEqual(t, backend1, ts.mp.backendIO)
require.NotEqual(t, backend1, ts.mp.backendIO.Load())
return nil
},
backend: func(packetIO *pnet.PacketIO) error {
Expand Down Expand Up @@ -591,10 +591,10 @@ func TestCustomHandshake(t *testing.T) {
{
client: nil,
proxy: func(_, _ *pnet.PacketIO) error {
backend1 := ts.mp.backendIO
backend1 := ts.mp.backendIO.Load()
ts.mp.Redirect(ts.tc.backendListener.Addr().String())
ts.mp.getEventReceiver().(*mockEventReceiver).checkEvent(t, eventSucceed)
require.NotEqual(t, backend1, ts.mp.backendIO)
require.NotEqual(t, backend1, ts.mp.backendIO.Load())
return nil
},
backend: ts.redirectSucceed4Backend,
Expand Down
40 changes: 21 additions & 19 deletions pkg/proxy/net/packetio.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,13 +41,13 @@ import (
"crypto/tls"
"io"
"net"
"sync/atomic"
"time"

"github.com/pingcap/TiProxy/lib/util/errors"
"github.com/pingcap/tidb/errno"
"github.com/pingcap/tidb/parser/mysql"
"github.com/pingcap/tidb/util/dbterror"
"go.uber.org/atomic"
)

var (
Expand Down Expand Up @@ -75,11 +75,12 @@ func (f *rdbufConn) Read(b []byte) (int, error) {

// PacketIO is a helper to read and write sql and proxy protocol.
type PacketIO struct {
inBytes uint64
outBytes uint64
conn net.Conn
inBytes uint64
outBytes uint64
// conn is written during TLS handshake and read during setting keep alive concurrently.
conn atomic.Pointer[net.Conn]
buf *bufio.ReadWriter
proxyInited *atomic.Bool
proxyInited atomic.Bool
proxy *Proxy
remoteAddr net.Addr
wrap error
Expand All @@ -92,15 +93,16 @@ func NewPacketIO(conn net.Conn, opts ...PacketIOption) *PacketIO {
bufio.NewWriterSize(conn, defaultWriterSize),
)
p := &PacketIO{
conn: &rdbufConn{
conn,
buf.Reader,
},
sequence: 0,
// TODO: disable it by default now
proxyInited: atomic.NewBool(true),
buf: buf,
buf: buf,
}
// TODO: disable it by default now
p.proxyInited.Store(true)
cn := (net.Conn)(&rdbufConn{
conn,
buf.Reader,
})
p.conn.Store(&cn)
for _, opt := range opts {
opt(p)
}
Expand All @@ -120,14 +122,14 @@ func (p *PacketIO) Proxy() *Proxy {
}

func (p *PacketIO) LocalAddr() net.Addr {
return p.conn.LocalAddr()
return (*p.conn.Load()).LocalAddr()
}

func (p *PacketIO) RemoteAddr() net.Addr {
if p.remoteAddr != nil {
return p.remoteAddr
}
return p.conn.RemoteAddr()
return (*p.conn.Load()).RemoteAddr()
}

func (p *PacketIO) ResetSequence() {
Expand Down Expand Up @@ -256,7 +258,7 @@ func (p *PacketIO) OutBytes() uint64 {
}

func (p *PacketIO) TLSConnectionState() tls.ConnectionState {
if tlsConn, ok := p.conn.(*tls.Conn); ok {
if tlsConn, ok := (*p.conn.Load()).(*tls.Conn); ok {
return tlsConn.ConnectionState()
}
return tls.ConnectionState{}
Expand All @@ -274,21 +276,21 @@ func (p *PacketIO) Flush() error {
// This function normally costs 1ms, so don't call it too frequently.
// This function may incorrectly return true if the system is extremely slow.
func (p *PacketIO) IsPeerActive() bool {
if err := p.conn.SetReadDeadline(time.Now().Add(time.Millisecond)); err != nil {
if err := (*p.conn.Load()).SetReadDeadline(time.Now().Add(time.Millisecond)); err != nil {
return false
}
active := true
if _, err := p.buf.Peek(1); err != nil {
active = !errors.Is(err, io.EOF)
}
if err := p.conn.SetReadDeadline(time.Time{}); err != nil {
if err := (*p.conn.Load()).SetReadDeadline(time.Time{}); err != nil {
return false
}
return active
}

func (p *PacketIO) GracefulClose() error {
if err := p.conn.SetDeadline(time.Now()); err != nil && !errors.Is(err, net.ErrClosed) {
if err := (*p.conn.Load()).SetDeadline(time.Now()); err != nil && !errors.Is(err, net.ErrClosed) {
return err
}
return nil
Expand All @@ -302,7 +304,7 @@ func (p *PacketIO) Close() error {
errs = append(errs, err)
}
*/
if err := p.conn.Close(); err != nil && !errors.Is(err, net.ErrClosed) {
if err := (*p.conn.Load()).Close(); err != nil && !errors.Is(err, net.ErrClosed) {
errs = append(errs, err)
}
return p.wrapErr(errors.Collect(ErrCloseConn, errs...))
Expand Down
4 changes: 1 addition & 3 deletions pkg/proxy/net/packetio_options.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,14 +16,12 @@ package net

import (
"net"

"go.uber.org/atomic"
)

type PacketIOption = func(*PacketIO)

func WithProxy(pi *PacketIO) {
pi.proxyInited = atomic.NewBool(true)
pi.proxyInited.Store(true)
}

func WithWrapError(err error) func(pi *PacketIO) {
Expand Down
2 changes: 1 addition & 1 deletion pkg/proxy/net/packetio_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,7 @@ func TestPacketIO(t *testing.T) {
func TestTLS(t *testing.T) {
stls, ctls, err := security.CreateTLSConfigForTest()
require.NoError(t, err)
message := []byte("hello wolrd")
message := []byte("hello world")
testTCPConn(t,
func(t *testing.T, cli *PacketIO) {
data, err := cli.ReadPacket()
Expand Down
2 changes: 1 addition & 1 deletion pkg/proxy/net/proxy_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ func TestProxyParse(t *testing.T) {
func(t *testing.T, srv *PacketIO) {
// skip 4 bytes of magic
var hdr [4]byte
_, err := io.ReadFull(srv.conn, hdr[:])
_, err := io.ReadFull(*srv.conn.Load(), hdr[:])
require.NoError(t, err)

// try to parse V2
Expand Down
Loading