diff --git a/pkg/proxy/backend/authenticator.go b/pkg/proxy/backend/authenticator.go index 28432187..75efc324 100644 --- a/pkg/proxy/backend/authenticator.go +++ b/pkg/proxy/backend/authenticator.go @@ -99,7 +99,8 @@ func (auth *Authenticator) handshakeFirstTime(logger *zap.Logger, cctx ConnConte proxyCapability ^= pnet.ClientSSL } - if err := clientIO.WriteInitialHandshake(proxyCapability, auth.salt, mysql.AuthNativePassword, handshakeHandler.GetServerVersion()); err != nil { + cid, _ := cctx.Value(ConnContextKeyConnID).(uint64) + if err := clientIO.WriteInitialHandshake(proxyCapability, auth.salt, mysql.AuthNativePassword, handshakeHandler.GetServerVersion(), cid); err != nil { return err } pkt, isSSL, err := clientIO.ReadSSLRequestOrHandshakeResp() @@ -283,7 +284,7 @@ func (auth *Authenticator) readInitialHandshake(backendIO *pnet.PacketIO) (serve err = pnet.ParseErrorPacket(serverPkt) return } - capability, _ = pnet.ParseInitialHandshake(serverPkt) + capability, _, _ = pnet.ParseInitialHandshake(serverPkt) return } diff --git a/pkg/proxy/backend/backend_conn_mgr.go b/pkg/proxy/backend/backend_conn_mgr.go index 3d9f4d01..3e936beb 100644 --- a/pkg/proxy/backend/backend_conn_mgr.go +++ b/pkg/proxy/backend/backend_conn_mgr.go @@ -142,6 +142,7 @@ func NewBackendConnManager(logger *zap.Logger, handshakeHandler HandshakeHandler redirectResCh: make(chan *redirectResult, 1), quitSource: SrcClientQuit, } + mgr.SetValue(ConnContextKeyConnID, connectionID) return mgr } diff --git a/pkg/proxy/backend/backend_conn_mgr_test.go b/pkg/proxy/backend/backend_conn_mgr_test.go index 74a7f319..e1640888 100644 --- a/pkg/proxy/backend/backend_conn_mgr_test.go +++ b/pkg/proxy/backend/backend_conn_mgr_test.go @@ -983,3 +983,22 @@ func TestKeepAlive(t *testing.T) { } ts.runTests(runners) } + +func TestConnID(t *testing.T) { + ids := []uint64{0, 4, 9} + for _, id := range ids { + ts := newBackendMgrTester(t, func(config *testConfig) { + config.proxyConfig.connectionID = id + }) + runners := []runner{{ + client: func(packetIO *pnet.PacketIO) error { + err := ts.mc.authenticate(packetIO) + require.Equal(t, ts.mc.connid, id) + return err + }, + proxy: ts.firstHandshake4Proxy, + backend: ts.handshake4Backend, + }} + ts.runTests(runners) + } +} diff --git a/pkg/proxy/backend/handshake_handler.go b/pkg/proxy/backend/handshake_handler.go index d7a4389d..8cf4c198 100644 --- a/pkg/proxy/backend/handshake_handler.go +++ b/pkg/proxy/backend/handshake_handler.go @@ -16,6 +16,7 @@ type ConnContextKey string const ( ConnContextKeyTLSState ConnContextKey = "tls-state" + ConnContextKeyConnID ConnContextKey = "conn-id" ) type ErrorSource int diff --git a/pkg/proxy/backend/mock_backend_test.go b/pkg/proxy/backend/mock_backend_test.go index f074d2ed..a900edd0 100644 --- a/pkg/proxy/backend/mock_backend_test.go +++ b/pkg/proxy/backend/mock_backend_test.go @@ -64,7 +64,7 @@ func (mb *mockBackend) authenticate(packetIO *pnet.PacketIO) error { } var err error // write initial handshake - if err = packetIO.WriteInitialHandshake(mb.capability, mb.salt, mb.authPlugin, pnet.ServerVersion); err != nil { + if err = packetIO.WriteInitialHandshake(mb.capability, mb.salt, mb.authPlugin, pnet.ServerVersion, 100); err != nil { return err } // read the response diff --git a/pkg/proxy/backend/mock_client_test.go b/pkg/proxy/backend/mock_client_test.go index 4774b7ca..1db06799 100644 --- a/pkg/proxy/backend/mock_client_test.go +++ b/pkg/proxy/backend/mock_client_test.go @@ -45,7 +45,8 @@ func newClientConfig() *clientConfig { } type mockClient struct { - err error + connid uint64 + err error // Inputs that assigned by the test and will be sent to the server. *clientConfig // Outputs that received from the server and will be checked by the test. @@ -68,9 +69,10 @@ func (mc *mockClient) authenticate(packetIO *pnet.PacketIO) error { if err != nil { return err } - serverCap, serverVersion := pnet.ParseInitialHandshake(pkt) + serverCap, connid, serverVersion := pnet.ParseInitialHandshake(pkt) mc.capability = mc.capability & serverCap mc.serverVersion = serverVersion + mc.connid = connid resp := &pnet.HandshakeResp{ User: mc.username, diff --git a/pkg/proxy/backend/mock_proxy_test.go b/pkg/proxy/backend/mock_proxy_test.go index 9fcac1a0..20d6c773 100644 --- a/pkg/proxy/backend/mock_proxy_test.go +++ b/pkg/proxy/backend/mock_proxy_test.go @@ -22,6 +22,7 @@ type proxyConfig struct { sessionToken string capability pnet.Capability waitRedirect bool + connectionID uint64 } func newProxyConfig() *proxyConfig { @@ -50,7 +51,7 @@ func newMockProxy(t *testing.T, cfg *proxyConfig) *mockProxy { mp := &mockProxy{ proxyConfig: cfg, logger: lg.Named("mockProxy"), - BackendConnManager: NewBackendConnManager(lg, cfg.handler, 0, &BCConfig{ + BackendConnManager: NewBackendConnManager(lg, cfg.handler, cfg.connectionID, &BCConfig{ CheckBackendInterval: cfg.checkBackendInterval, }), } diff --git a/pkg/proxy/net/mysql.go b/pkg/proxy/net/mysql.go index 9e7246a6..21288b4c 100644 --- a/pkg/proxy/net/mysql.go +++ b/pkg/proxy/net/mysql.go @@ -23,16 +23,15 @@ const ( var ( ServerVersion = mysql.ServerVersion Collation = uint8(mysql.DefaultCollationID) - ConnID = 100 Status = mysql.ServerStatusAutocommit ) // ParseInitialHandshake parses the initial handshake received from the server. -func ParseInitialHandshake(data []byte) (Capability, string) { +func ParseInitialHandshake(data []byte) (Capability, uint64, string) { // skip min version serverVersion := string(data[1 : 1+bytes.IndexByte(data[1:], 0)]) pos := 1 + len(serverVersion) + 1 - // skip connection id + connid := uint32(binary.LittleEndian.Uint32(data[pos : pos+4])) // skip salt first part // skip filter pos += 4 + 8 + 1 @@ -52,7 +51,7 @@ func ParseInitialHandshake(data []byte) (Capability, string) { // skip salt second part // skip auth plugin } - return Capability(capability), serverVersion + return Capability(capability), uint64(connid), serverVersion } // HandshakeResp indicates the response read from the client. diff --git a/pkg/proxy/net/packetio_mysql.go b/pkg/proxy/net/packetio_mysql.go index 2132d658..a1b77646 100644 --- a/pkg/proxy/net/packetio_mysql.go +++ b/pkg/proxy/net/packetio_mysql.go @@ -18,7 +18,7 @@ var ( // WriteInitialHandshake writes an initial handshake as a server. // It's used for tenant-aware routing and testing. -func (p *PacketIO) WriteInitialHandshake(capability Capability, salt []byte, authPlugin string, serverVersion string) error { +func (p *PacketIO) WriteInitialHandshake(capability Capability, salt []byte, authPlugin string, serverVersion string, connID uint64) error { saltLen := len(salt) if saltLen < 8 { return ErrSaltNotLongEnough @@ -34,7 +34,7 @@ func (p *PacketIO) WriteInitialHandshake(capability Capability, salt []byte, aut data = append(data, serverVersion...) data = append(data, 0) // connection id - data = append(data, byte(ConnID), byte(ConnID>>8), byte(ConnID>>16), byte(ConnID>>24)) + data = append(data, byte(connID), byte(connID>>8), byte(connID>>16), byte(connID>>24)) // auth-plugin-data-part-1 data = append(data, salt[0:8]...) // filler [00] diff --git a/pkg/proxy/net/packetio_test.go b/pkg/proxy/net/packetio_test.go index 701941ed..59140a0f 100644 --- a/pkg/proxy/net/packetio_test.go +++ b/pkg/proxy/net/packetio_test.go @@ -93,9 +93,9 @@ func TestPacketIO(t *testing.T) { } // send handshake - require.NoError(t, srv.WriteInitialHandshake(0, salt[:], AuthNativePassword, ServerVersion)) + require.NoError(t, srv.WriteInitialHandshake(0, salt[:], AuthNativePassword, ServerVersion, 100)) // salt should not be long enough - require.ErrorIs(t, srv.WriteInitialHandshake(0, make([]byte, 4), AuthNativePassword, ServerVersion), ErrSaltNotLongEnough) + require.ErrorIs(t, srv.WriteInitialHandshake(0, make([]byte, 4), AuthNativePassword, ServerVersion, 100), ErrSaltNotLongEnough) // expect correct and wrong capability flags _, isSSL, err := srv.ReadSSLRequestOrHandshakeResp()