Skip to content

Commit

Permalink
backend, net: support connection attrs in COM_CHANGE_USER (#367)
Browse files Browse the repository at this point in the history
  • Loading branch information
djshow832 authored Sep 12, 2023
1 parent eb5b4b9 commit 8fe7525
Show file tree
Hide file tree
Showing 10 changed files with 276 additions and 67 deletions.
8 changes: 4 additions & 4 deletions pkg/proxy/backend/authenticator.go
Original file line number Diff line number Diff line change
Expand Up @@ -371,10 +371,10 @@ func (auth *Authenticator) handleSecondAuthResult(backendIO *pnet.PacketIO) erro
}

// changeUser is called once the client sends COM_CHANGE_USER.
func (auth *Authenticator) changeUser(username, db string) {
auth.user = username
auth.dbname = db
// TODO: attrs
func (auth *Authenticator) changeUser(req *pnet.ChangeUserReq) {
auth.user = req.User
auth.dbname = req.DB
auth.attrs = req.Attrs
}

// updateCurrentDB is called once the client sends COM_INIT_DB or `use db`.
Expand Down
11 changes: 5 additions & 6 deletions pkg/proxy/backend/backend_conn_mgr.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@ import (

"github.com/cenkalti/backoff/v4"
gomysql "github.com/go-mysql-org/go-mysql/mysql"
"github.com/pingcap/tidb/parser/mysql"
"github.com/pingcap/tiproxy/lib/config"
"github.com/pingcap/tiproxy/lib/util/errors"
"github.com/pingcap/tiproxy/lib/util/waitgroup"
Expand Down Expand Up @@ -130,7 +129,7 @@ func NewBackendConnManager(logger *zap.Logger, handshakeHandler HandshakeHandler
logger: logger,
config: config,
connectionID: connectionID,
cmdProcessor: NewCmdProcessor(),
cmdProcessor: NewCmdProcessor(logger.Named("cp")),
handshakeHandler: handshakeHandler,
authenticator: &Authenticator{
proxyProtocol: config.ProxyProtocol,
Expand Down Expand Up @@ -258,7 +257,7 @@ func (mgr *BackendConnManager) ExecuteCmd(ctx context.Context, request []byte) (
mgr.handshakeHandler.OnTraffic(mgr)
}()
if len(request) < 1 {
err = mysql.ErrMalformPacket
err = gomysql.ErrMalformPacket
return
}
cmd := pnet.Command(request[0])
Expand Down Expand Up @@ -302,9 +301,9 @@ func (mgr *BackendConnManager) ExecuteCmd(ctx context.Context, request []byte) (
return
}
case pnet.ComChangeUser:
username, db := pnet.ParseChangeUser(request)
mgr.authenticator.changeUser(username, db)
return
// Critical errors should not happen because CmdProcessor has parsed it already.
req, _ := pnet.ParseChangeUser(request, mgr.authenticator.capability)
mgr.authenticator.changeUser(req)
}
}
// Even if it meets an MySQL error, it may have changed the status, such as when executing multi-statements.
Expand Down
124 changes: 99 additions & 25 deletions pkg/proxy/backend/backend_conn_mgr_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,15 @@ func (ts *backendMgrTester) redirectSucceed4Backend(packetIO *pnet.PacketIO) err
return nil
}

func (ts *backendMgrTester) redirectSucceed4Proxy(_, _ *pnet.PacketIO) error {
backend1 := ts.mp.backendIO.Load()
ts.mp.Redirect(ts.tc.backendListener.Addr().String())
ts.mp.getEventReceiver().(*mockEventReceiver).checkEvent(ts.t, eventSucceed)
require.NotEqual(ts.t, backend1, ts.mp.backendIO.Load())
require.Equal(ts.t, SrcClientQuit, ts.mp.QuitSource())
return nil
}

func (ts *backendMgrTester) forwardCmd4Proxy(clientIO, backendIO *pnet.PacketIO) error {
clientIO.ResetSequence()
request, err := clientIO.ReadPacket()
Expand Down Expand Up @@ -232,15 +241,8 @@ func TestNormalRedirect(t *testing.T) {
},
// 2nd handshake: redirect immediately after connection
{
client: nil,
proxy: func(_, _ *pnet.PacketIO) error {
backend1 := ts.mp.backendIO.Load()
ts.mp.Redirect(ts.tc.backendListener.Addr().String())
ts.mp.getEventReceiver().(*mockEventReceiver).checkEvent(t, eventSucceed)
require.NotEqual(t, backend1, ts.mp.backendIO.Load())
require.Equal(t, SrcClientQuit, ts.mp.QuitSource())
return nil
},
client: nil,
proxy: ts.redirectSucceed4Proxy,
backend: ts.redirectSucceed4Backend,
},
}
Expand Down Expand Up @@ -316,6 +318,33 @@ func TestRedirectInTxn(t *testing.T) {
{
proxy: ts.checkNotRedirected4Proxy,
},
// CHANGE_USER clears the txn
{
client: func(packetIO *pnet.PacketIO) error {
ts.mc.cmd = pnet.ComChangeUser
return ts.mc.request(packetIO)
},
proxy: ts.redirectAfterCmd4Proxy,
backend: func(packetIO *pnet.PacketIO) error {
// respond to the client request
err := ts.respondWithNoTxn4Backend(packetIO)
require.NoError(t, err)
return ts.redirectSucceed4Backend(packetIO)
},
},
// start a transaction to make it unredirect-able
{
client: func(packetIO *pnet.PacketIO) error {
ts.mc.cmd = pnet.ComQuery
return ts.mc.request(packetIO)
},
proxy: ts.forwardCmd4Proxy,
backend: ts.startTxn4Backend,
},
// try to redirect but it doesn't redirect
{
proxy: ts.checkNotRedirected4Proxy,
},
// internal COMMIT fails and the `begin` is not sent
{
client: func(packetIO *pnet.PacketIO) error {
Expand Down Expand Up @@ -492,14 +521,7 @@ func TestSpecialCmds(t *testing.T) {
// 2nd handshake
{
client: nil,
proxy: func(_, _ *pnet.PacketIO) error {
backend1 := ts.mp.backendIO.Load()
ts.mp.Redirect(ts.tc.backendListener.Addr().String())
ts.mp.getEventReceiver().(*mockEventReceiver).checkEvent(t, eventSucceed)
require.NotEqual(t, backend1, ts.mp.backendIO.Load())
require.Equal(t, SrcClientQuit, ts.mp.QuitSource())
return nil
},
proxy: ts.redirectSucceed4Proxy,
backend: func(packetIO *pnet.PacketIO) error {
ts.mb.sessionStates = "{\"current-db\":\"session_db\"}"
require.NoError(t, ts.redirectSucceed4Backend(packetIO))
Expand Down Expand Up @@ -596,14 +618,8 @@ func TestCustomHandshake(t *testing.T) {
},
// 2nd handshake
{
client: nil,
proxy: func(_, _ *pnet.PacketIO) error {
backend1 := ts.mp.backendIO.Load()
ts.mp.Redirect(ts.tc.backendListener.Addr().String())
ts.mp.getEventReceiver().(*mockEventReceiver).checkEvent(t, eventSucceed)
require.NotEqual(t, backend1, ts.mp.backendIO.Load())
return nil
},
client: nil,
proxy: ts.redirectSucceed4Proxy,
backend: ts.redirectSucceed4Backend,
},
{
Expand Down Expand Up @@ -1002,3 +1018,61 @@ func TestConnID(t *testing.T) {
ts.runTests(runners)
}
}

func TestConnAttrs(t *testing.T) {
ts := newBackendMgrTester(t)
attr1 := map[string]string{"k1": "v1"}
attr2 := map[string]string{"k2": "v2"}
runners := []runner{
// 1st handshake
{
client: func(packetIO *pnet.PacketIO) error {
ts.mc.attrs = attr1
return ts.mc.authenticate(packetIO)
},
proxy: ts.firstHandshake4Proxy,
backend: func(packetIO *pnet.PacketIO) error {
err := ts.handshake4Backend(packetIO)
require.NoError(t, err)
require.Equal(t, attr1, ts.mb.attrs)
return nil
},
},
// 2nd handshake
{
proxy: ts.redirectSucceed4Proxy,
backend: func(packetIO *pnet.PacketIO) error {
err := ts.redirectSucceed4Backend(packetIO)
require.NoError(t, err)
require.Equal(t, attr1, ts.mb.attrs)
return nil
},
},
// CHANGE_USER updates attrs
{
client: func(packetIO *pnet.PacketIO) error {
ts.mc.cmd = pnet.ComChangeUser
ts.mc.attrs = attr2
return ts.mc.request(packetIO)
},
proxy: ts.forwardCmd4Proxy,
backend: func(packetIO *pnet.PacketIO) error {
err := ts.respondWithNoTxn4Backend(packetIO)
require.NoError(t, err)
require.Equal(t, attr2, ts.mb.attrs)
return nil
},
},
// 2nd handshake
{
proxy: ts.redirectSucceed4Proxy,
backend: func(packetIO *pnet.PacketIO) error {
err := ts.redirectSucceed4Backend(packetIO)
require.NoError(t, err)
require.Equal(t, attr2, ts.mb.attrs)
return nil
},
},
}
ts.runTests(runners)
}
5 changes: 4 additions & 1 deletion pkg/proxy/backend/cmd_processor.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (
gomysql "github.com/go-mysql-org/go-mysql/mysql"
"github.com/pingcap/tidb/parser/mysql"
pnet "github.com/pingcap/tiproxy/pkg/proxy/net"
"go.uber.org/zap"
)

const (
Expand All @@ -25,12 +26,14 @@ type CmdProcessor struct {
capability pnet.Capability
// Only includes in_trans or quit status.
serverStatus uint32
logger *zap.Logger
}

func NewCmdProcessor() *CmdProcessor {
func NewCmdProcessor(logger *zap.Logger) *CmdProcessor {
return &CmdProcessor{
serverStatus: 0,
preparedStmtStatus: make(map[int]uint32),
logger: logger,
}
}

Expand Down
25 changes: 20 additions & 5 deletions pkg/proxy/backend/cmd_processor_exec.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,13 @@ import (
"encoding/binary"
"strings"

gomysql "github.com/go-mysql-org/go-mysql/mysql"
"github.com/pingcap/tidb/parser"
"github.com/pingcap/tidb/parser/mysql"
"github.com/pingcap/tiproxy/lib/util/errors"
pnet "github.com/pingcap/tiproxy/pkg/proxy/net"
"github.com/siddontang/go/hack"
"go.uber.org/zap"
)

// executeCmd forwards requests and responses between the client and the backend.
Expand All @@ -38,15 +40,11 @@ func (cp *CmdProcessor) executeCmd(request []byte, clientIO, backendIO *pnet.Pac

func (cp *CmdProcessor) forwardCommand(clientIO, backendIO *pnet.PacketIO, request []byte) error {
cmd := pnet.Command(request[0])
// ComChangeUser is special: we need to modify the packet before forwarding.
if cmd != pnet.ComChangeUser {
if err := backendIO.WritePacket(request, true); err != nil {
return err
}
} else {
user, db := pnet.ParseChangeUser(request)
if err := backendIO.WritePacket(pnet.MakeChangeUser(user, db, unknownAuthPlugin, nil), true); err != nil {
return err
}
}
switch cmd {
case pnet.ComStmtPrepare:
Expand Down Expand Up @@ -271,6 +269,23 @@ func (cp *CmdProcessor) forwardSendLongDataCmd(request []byte) error {
}

func (cp *CmdProcessor) forwardChangeUserCmd(clientIO, backendIO *pnet.PacketIO, request []byte) error {
req, err := pnet.ParseChangeUser(request, cp.capability)
if err != nil {
cp.logger.Warn("parse COM_CHANGE_USER packet encounters error", zap.Error(err))
var warning *errors.Warning
if !errors.As(err, &warning) {
return gomysql.ErrMalformPacket
}
}
// The client may use the TiProxy salt to generate the auth data instead of using the TiDB salt,
// so we need another switch-auth request to pass the TiDB salt to the client.
// See https://github.com/pingcap/tiproxy/issues/127.
req.AuthPlugin = unknownAuthPlugin
req.AuthData = nil
if err := backendIO.WritePacket(pnet.MakeChangeUser(req, cp.capability), true); err != nil {
return err
}

for {
response, err := forwardOnePacket(clientIO, backendIO, true)
if err != nil {
Expand Down
22 changes: 22 additions & 0 deletions pkg/proxy/backend/mock_backend_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
"encoding/binary"

"github.com/go-mysql-org/go-mysql/mysql"
"github.com/pingcap/tiproxy/lib/util/errors"
pnet "github.com/pingcap/tiproxy/pkg/proxy/net"
)

Expand Down Expand Up @@ -132,6 +133,22 @@ func (mb *mockBackend) verifyPassword(packetIO *pnet.PacketIO, resp *pnet.Handsh
return nil
}

func (mb *mockBackend) changeUser(pkt []byte) error {
req, err := pnet.ParseChangeUser(pkt, mb.capability)
if err != nil {
return err
}
mb.username = req.User
mb.db = req.DB
mb.attrs = req.Attrs
mb.authData = req.AuthData
mb.authPlugin = req.AuthPlugin
if mb.authPlugin != unknownAuthPlugin {
return errors.New("should use different auth plugin")
}
return nil
}

func (mb *mockBackend) respond(packetIO *pnet.PacketIO) error {
if mb.abnormalExit {
return packetIO.Close()
Expand All @@ -150,6 +167,11 @@ func (mb *mockBackend) respondOnce(packetIO *pnet.PacketIO) error {
if err != nil {
return err
}
if pnet.Command(pkt[0]) == pnet.ComChangeUser {
if err := mb.changeUser(pkt); err != nil {
return err
}
}
switch mb.respondType {
case responseTypeOK:
return mb.respondOK(packetIO)
Expand Down
10 changes: 9 additions & 1 deletion pkg/proxy/backend/mock_client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -166,7 +166,15 @@ func (mc *mockClient) request(packetIO *pnet.PacketIO) error {
}

func (mc *mockClient) requestChangeUser(packetIO *pnet.PacketIO) error {
data := pnet.MakeChangeUser(mc.username, mc.dbName, mysql.AuthNativePassword, mc.authData)
req := &pnet.ChangeUserReq{
User: mc.username,
DB: mc.dbName,
AuthPlugin: mysql.AuthNativePassword,
AuthData: mc.authData,
Charset: []byte{0x11, 0x22},
Attrs: mc.attrs,
}
data := pnet.MakeChangeUser(req, mc.capability)
if err := packetIO.WritePacket(data, true); err != nil {
return err
}
Expand Down
6 changes: 5 additions & 1 deletion pkg/proxy/backend/testsuite_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,11 @@ func (ts *testSuite) changeDB(db string) {
func (ts *testSuite) changeUser(username, db string) {
ts.mc.username = username
ts.mc.dbName = db
ts.mp.authenticator.changeUser(username, db)
req := &pnet.ChangeUserReq{
User: username,
DB: db,
}
ts.mp.authenticator.changeUser(req)
}

func (ts *testSuite) runAndCheck(t *testing.T, c checker, clientRunner, backendRunner func(*pnet.PacketIO) error,
Expand Down
Loading

0 comments on commit 8fe7525

Please sign in to comment.