Skip to content

Commit

Permalink
proxy: use a real proxy conn id (#348)
Browse files Browse the repository at this point in the history
Signed-off-by: xhe <xw897002528@gmail.com>
  • Loading branch information
xhebox authored Aug 30, 2023
1 parent bcb0848 commit efd29b1
Show file tree
Hide file tree
Showing 10 changed files with 38 additions and 14 deletions.
5 changes: 3 additions & 2 deletions pkg/proxy/backend/authenticator.go
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,8 @@ func (auth *Authenticator) handshakeFirstTime(logger *zap.Logger, cctx ConnConte
proxyCapability ^= pnet.ClientSSL
}

if err := clientIO.WriteInitialHandshake(proxyCapability, auth.salt, mysql.AuthNativePassword, handshakeHandler.GetServerVersion()); err != nil {
cid, _ := cctx.Value(ConnContextKeyConnID).(uint64)
if err := clientIO.WriteInitialHandshake(proxyCapability, auth.salt, mysql.AuthNativePassword, handshakeHandler.GetServerVersion(), cid); err != nil {
return err
}
pkt, isSSL, err := clientIO.ReadSSLRequestOrHandshakeResp()
Expand Down Expand Up @@ -283,7 +284,7 @@ func (auth *Authenticator) readInitialHandshake(backendIO *pnet.PacketIO) (serve
err = pnet.ParseErrorPacket(serverPkt)
return
}
capability, _ = pnet.ParseInitialHandshake(serverPkt)
capability, _, _ = pnet.ParseInitialHandshake(serverPkt)
return
}

Expand Down
1 change: 1 addition & 0 deletions pkg/proxy/backend/backend_conn_mgr.go
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,7 @@ func NewBackendConnManager(logger *zap.Logger, handshakeHandler HandshakeHandler
redirectResCh: make(chan *redirectResult, 1),
quitSource: SrcClientQuit,
}
mgr.SetValue(ConnContextKeyConnID, connectionID)
return mgr
}

Expand Down
19 changes: 19 additions & 0 deletions pkg/proxy/backend/backend_conn_mgr_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -983,3 +983,22 @@ func TestKeepAlive(t *testing.T) {
}
ts.runTests(runners)
}

func TestConnID(t *testing.T) {
ids := []uint64{0, 4, 9}
for _, id := range ids {
ts := newBackendMgrTester(t, func(config *testConfig) {
config.proxyConfig.connectionID = id
})
runners := []runner{{
client: func(packetIO *pnet.PacketIO) error {
err := ts.mc.authenticate(packetIO)
require.Equal(t, ts.mc.connid, id)
return err
},
proxy: ts.firstHandshake4Proxy,
backend: ts.handshake4Backend,
}}
ts.runTests(runners)
}
}
1 change: 1 addition & 0 deletions pkg/proxy/backend/handshake_handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ type ConnContextKey string

const (
ConnContextKeyTLSState ConnContextKey = "tls-state"
ConnContextKeyConnID ConnContextKey = "conn-id"
)

type ErrorSource int
Expand Down
2 changes: 1 addition & 1 deletion pkg/proxy/backend/mock_backend_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ func (mb *mockBackend) authenticate(packetIO *pnet.PacketIO) error {
}
var err error
// write initial handshake
if err = packetIO.WriteInitialHandshake(mb.capability, mb.salt, mb.authPlugin, pnet.ServerVersion); err != nil {
if err = packetIO.WriteInitialHandshake(mb.capability, mb.salt, mb.authPlugin, pnet.ServerVersion, 100); err != nil {
return err
}
// read the response
Expand Down
6 changes: 4 additions & 2 deletions pkg/proxy/backend/mock_client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,8 @@ func newClientConfig() *clientConfig {
}

type mockClient struct {
err error
connid uint64
err error
// Inputs that assigned by the test and will be sent to the server.
*clientConfig
// Outputs that received from the server and will be checked by the test.
Expand All @@ -68,9 +69,10 @@ func (mc *mockClient) authenticate(packetIO *pnet.PacketIO) error {
if err != nil {
return err
}
serverCap, serverVersion := pnet.ParseInitialHandshake(pkt)
serverCap, connid, serverVersion := pnet.ParseInitialHandshake(pkt)
mc.capability = mc.capability & serverCap
mc.serverVersion = serverVersion
mc.connid = connid

resp := &pnet.HandshakeResp{
User: mc.username,
Expand Down
3 changes: 2 additions & 1 deletion pkg/proxy/backend/mock_proxy_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ type proxyConfig struct {
sessionToken string
capability pnet.Capability
waitRedirect bool
connectionID uint64
}

func newProxyConfig() *proxyConfig {
Expand Down Expand Up @@ -50,7 +51,7 @@ func newMockProxy(t *testing.T, cfg *proxyConfig) *mockProxy {
mp := &mockProxy{
proxyConfig: cfg,
logger: lg.Named("mockProxy"),
BackendConnManager: NewBackendConnManager(lg, cfg.handler, 0, &BCConfig{
BackendConnManager: NewBackendConnManager(lg, cfg.handler, cfg.connectionID, &BCConfig{
CheckBackendInterval: cfg.checkBackendInterval,
}),
}
Expand Down
7 changes: 3 additions & 4 deletions pkg/proxy/net/mysql.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,16 +23,15 @@ const (
var (
ServerVersion = mysql.ServerVersion
Collation = uint8(mysql.DefaultCollationID)
ConnID = 100
Status = mysql.ServerStatusAutocommit
)

// ParseInitialHandshake parses the initial handshake received from the server.
func ParseInitialHandshake(data []byte) (Capability, string) {
func ParseInitialHandshake(data []byte) (Capability, uint64, string) {
// skip min version
serverVersion := string(data[1 : 1+bytes.IndexByte(data[1:], 0)])
pos := 1 + len(serverVersion) + 1
// skip connection id
connid := uint32(binary.LittleEndian.Uint32(data[pos : pos+4]))
// skip salt first part
// skip filter
pos += 4 + 8 + 1
Expand All @@ -52,7 +51,7 @@ func ParseInitialHandshake(data []byte) (Capability, string) {
// skip salt second part
// skip auth plugin
}
return Capability(capability), serverVersion
return Capability(capability), uint64(connid), serverVersion
}

// HandshakeResp indicates the response read from the client.
Expand Down
4 changes: 2 additions & 2 deletions pkg/proxy/net/packetio_mysql.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ var (

// WriteInitialHandshake writes an initial handshake as a server.
// It's used for tenant-aware routing and testing.
func (p *PacketIO) WriteInitialHandshake(capability Capability, salt []byte, authPlugin string, serverVersion string) error {
func (p *PacketIO) WriteInitialHandshake(capability Capability, salt []byte, authPlugin string, serverVersion string, connID uint64) error {
saltLen := len(salt)
if saltLen < 8 {
return ErrSaltNotLongEnough
Expand All @@ -34,7 +34,7 @@ func (p *PacketIO) WriteInitialHandshake(capability Capability, salt []byte, aut
data = append(data, serverVersion...)
data = append(data, 0)
// connection id
data = append(data, byte(ConnID), byte(ConnID>>8), byte(ConnID>>16), byte(ConnID>>24))
data = append(data, byte(connID), byte(connID>>8), byte(connID>>16), byte(connID>>24))
// auth-plugin-data-part-1
data = append(data, salt[0:8]...)
// filler [00]
Expand Down
4 changes: 2 additions & 2 deletions pkg/proxy/net/packetio_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -93,9 +93,9 @@ func TestPacketIO(t *testing.T) {
}

// send handshake
require.NoError(t, srv.WriteInitialHandshake(0, salt[:], AuthNativePassword, ServerVersion))
require.NoError(t, srv.WriteInitialHandshake(0, salt[:], AuthNativePassword, ServerVersion, 100))
// salt should not be long enough
require.ErrorIs(t, srv.WriteInitialHandshake(0, make([]byte, 4), AuthNativePassword, ServerVersion), ErrSaltNotLongEnough)
require.ErrorIs(t, srv.WriteInitialHandshake(0, make([]byte, 4), AuthNativePassword, ServerVersion, 100), ErrSaltNotLongEnough)

// expect correct and wrong capability flags
_, isSSL, err := srv.ReadSSLRequestOrHandshakeResp()
Expand Down

0 comments on commit efd29b1

Please sign in to comment.