Skip to content

Commit

Permalink
backend: handshake with client first (#108)
Browse files Browse the repository at this point in the history
  • Loading branch information
djshow832 authored Oct 10, 2022
1 parent b1e0a00 commit 3432991
Show file tree
Hide file tree
Showing 7 changed files with 88 additions and 71 deletions.
87 changes: 49 additions & 38 deletions pkg/proxy/backend/authenticator.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,13 +31,15 @@ var (
ErrCapabilityNegotiation = errors.New("capability negotiation failed")
)

// Other server capabilities are not supported.
const requiredCapabilities = pnet.ClientProtocol41
const requiredFrontendCaps = pnet.ClientProtocol41
const requiredBackendCaps = pnet.ClientDeprecateEOF | pnet.ClientSSL

// Other server capabilities are not supported. ClientDeprecateEOF is supported but TiDB 6.2.0 doesn't support it now.
const supportedServerCapabilities = pnet.ClientLongPassword | pnet.ClientFoundRows | pnet.ClientConnectWithDB |
pnet.ClientODBC | pnet.ClientLocalFiles | pnet.ClientInteractive | pnet.ClientSSL | pnet.ClientLongFlag |
pnet.ClientODBC | pnet.ClientLocalFiles | pnet.ClientInteractive | pnet.ClientLongFlag |
pnet.ClientTransactions | pnet.ClientReserved | pnet.ClientSecureConnection | pnet.ClientMultiStatements |
pnet.ClientMultiResults | pnet.ClientPluginAuth | pnet.ClientConnectAttrs | pnet.ClientPluginAuthLenencClientData |
pnet.ClientDeprecateEOF | requiredCapabilities
requiredFrontendCaps | requiredBackendCaps

// Authenticator handshakes with the client and the backend.
type Authenticator struct {
Expand Down Expand Up @@ -65,31 +67,7 @@ func (auth *Authenticator) handshakeFirstTime(logger *zap.Logger, clientIO, back
proxyCapability ^= pnet.ClientSSL
}

// read backend initial handshake
backendHandshake, backendCapabilityU, err := auth.readInitialHandshake(backendIO)
if err != nil {
return err
}
backendCapability := pnet.Capability(backendCapabilityU)
if backendCapability&pnet.ClientSSL == 0 {
// The error cannot be sent to the client because the client only expects an initial handshake packet.
// The only way is to log it and disconnect.
return errors.New("the TiDB server must enable TLS")
}
if common := proxyCapability & backendCapability; (proxyCapability^common)&^pnet.ClientSSL != 0 {
// TODO: need to do negotiation with backend
// 1. proxyCapability &= backendCapability
// 2. binary.LittleEndian.PutUint32(backendHandshake, proxyCapability.Uint32())
//
// it should exchange caps with the backend
// but TiDB does not send all of its supported capabilities
// thus we must ignore server capabilities
// however, I will log something
logger.Info("backend does not support capabilities from proxy", zap.Stringer("common", common), zap.Stringer("proxy", proxyCapability^common), zap.Stringer("backend", backendCapability^common))
}

// forward backend handshake
if err := clientIO.WritePacket(backendHandshake, true); err != nil {
if err := clientIO.WriteInitialHandshake(proxyCapability.Uint32(), make([]byte, 20), mysql.AuthNativePassword); err != nil {
return err
}
pkt, isSSL, err := clientIO.ReadSSLRequestOrHandshakeResp()
Expand All @@ -116,9 +94,9 @@ func (auth *Authenticator) handshakeFirstTime(logger *zap.Logger, clientIO, back
} else {
binary.LittleEndian.PutUint32(pkt, (frontendCapability | pnet.ClientSSL).Uint32())
}
if commonCaps := frontendCapability & requiredCapabilities; commonCaps != requiredCapabilities {
logger.Error("require frontend capabilities", zap.Stringer("common", commonCaps), zap.Stringer("required", requiredCapabilities))
return errors.Wrapf(ErrCapabilityNegotiation, "require %s", requiredCapabilities&^commonCaps)
if commonCaps := frontendCapability & requiredFrontendCaps; commonCaps != requiredFrontendCaps {
logger.Error("require frontend capabilities", zap.Stringer("common", commonCaps), zap.Stringer("required", requiredFrontendCaps))
return errors.Wrapf(ErrCapabilityNegotiation, "require %s from frontend", requiredFrontendCaps&^commonCaps)
}
commonCaps := frontendCapability & proxyCapability
if frontendCapability^commonCaps != 0 {
Expand Down Expand Up @@ -148,6 +126,34 @@ func (auth *Authenticator) handshakeFirstTime(logger *zap.Logger, clientIO, back
}
}

// read backend initial handshake
_, backendCapabilityU, err := auth.readInitialHandshake(backendIO)
if err != nil {
return err
}
backendCapability := pnet.Capability(backendCapabilityU)
if commonCaps := backendCapability & requiredBackendCaps; commonCaps != requiredBackendCaps {
// The error cannot be sent to the client because the client only expects an initial handshake packet.
// The only way is to log it and disconnect.
logger.Error("require backend capabilities", zap.Stringer("common", commonCaps), zap.Stringer("required", requiredBackendCaps))
return errors.Wrapf(ErrCapabilityNegotiation, "require %s from backend", requiredBackendCaps&^commonCaps)
}
if common := proxyCapability & backendCapability; (proxyCapability^common)&^pnet.ClientSSL != 0 {
// TODO: need to do negotiation with backend
// 1. proxyCapability &= backendCapability
// 2. binary.LittleEndian.PutUint32(backendHandshake, proxyCapability.Uint32())
//
// it should exchange caps with the backend
// but TiDB does not send all of its supported capabilities
// thus we must ignore server capabilities
// however, I will log something
logger.Info("backend does not support capabilities from proxy", zap.Stringer("common", common), zap.Stringer("proxy", proxyCapability^common), zap.Stringer("backend", backendCapability^common))
}

resp.Capability = auth.capability | mysql.ClientSSL
// Send an unknown auth plugin so that the backend will request the auth data again.
resp.AuthPlugin = "auth_unknown_plugin"
pkt = pnet.MakeHandshakeResponse(resp)
// write SSL Packet
if err := backendIO.WritePacket(pkt[:32], true); err != nil {
return err
Expand Down Expand Up @@ -235,12 +241,17 @@ func (auth *Authenticator) readInitialHandshake(backendIO *pnet.PacketIO) (serve
}

func (auth *Authenticator) writeAuthHandshake(backendIO *pnet.PacketIO, authData []byte) error {
// Always handshake with SSL enabled.
capability := auth.capability | mysql.ClientSSL
// Always enable auth_plugin.
capability |= mysql.ClientPluginAuth
data := pnet.MakeHandshakeResponse(auth.user, auth.dbname, mysql.AuthTiDBSessionToken,
auth.collation, authData, auth.attrs, capability)
// Always handshake with SSL enabled and enable auth_plugin.
resp := &pnet.HandshakeResp{
User: auth.user,
DB: auth.dbname,
AuthPlugin: mysql.AuthTiDBSessionToken,
Attrs: auth.attrs,
AuthData: authData,
Capability: auth.capability | mysql.ClientSSL | mysql.ClientPluginAuth,
Collation: auth.collation,
}
data := pnet.MakeHandshakeResponse(resp)

// write SSL req
if err := backendIO.WritePacket(data[:32], true); err != nil {
Expand Down
22 changes: 11 additions & 11 deletions pkg/proxy/backend/authenticator_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,14 @@ func TestUnsupportedCapability(t *testing.T) {
cfg.backendConfig.capability = defaultTestBackendCapability | mysql.ClientSSL
},
},
{
func(cfg *testConfig) {
cfg.backendConfig.capability = defaultTestBackendCapability & ^mysql.ClientDeprecateEOF
},
func(cfg *testConfig) {
cfg.backendConfig.capability = defaultTestBackendCapability | mysql.ClientDeprecateEOF
},
},
{
func(cfg *testConfig) {
cfg.clientConfig.capability = defaultTestClientCapability & ^mysql.ClientProtocol41
Expand Down Expand Up @@ -71,9 +79,9 @@ func TestUnsupportedCapability(t *testing.T) {
for _, cfgs := range cfgOverriders {
ts, clean := newTestSuite(t, tc, cfgs...)
ts.authenticateFirstTime(t, func(t *testing.T, _ *testSuite) {
if ts.mb.backendConfig.capability&mysql.ClientSSL == 0 {
require.ErrorContains(t, ts.mp.err, "must enable TLS")
} else if ts.mc.clientConfig.capability&mysql.ClientProtocol41 == 0 {
if ts.mb.backendConfig.capability&requiredBackendCaps.Uint32() != requiredBackendCaps.Uint32() {
require.ErrorIs(t, ts.mp.err, ErrCapabilityNegotiation)
} else if ts.mc.clientConfig.capability&requiredFrontendCaps.Uint32() != requiredFrontendCaps.Uint32() {
require.ErrorIs(t, ts.mp.err, ErrCapabilityNegotiation)
} else {
require.NoError(t, ts.mc.err)
Expand Down Expand Up @@ -119,14 +127,6 @@ func TestAuthPlugin(t *testing.T) {
cfg.backendConfig.authPlugin = mysql.AuthCachingSha2Password
},
},
{
func(cfg *testConfig) {
cfg.backendConfig.switchAuth = true
},
func(cfg *testConfig) {
cfg.backendConfig.switchAuth = false
},
},
{
func(cfg *testConfig) {
cfg.backendConfig.authSucceed = true
Expand Down
4 changes: 1 addition & 3 deletions pkg/proxy/backend/mock_backend_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,6 @@ type backendConfig struct {
capability uint32
status uint16
authSucceed bool
switchAuth bool
abnormalExit bool
}

Expand All @@ -46,7 +45,6 @@ func newBackendConfig() *backendConfig {
capability: defaultTestBackendCapability,
salt: mockSalt,
authPlugin: mysql.AuthCachingSha2Password,
switchAuth: true,
authSucceed: true,
loops: 1,
stmtNum: 1,
Expand Down Expand Up @@ -109,7 +107,7 @@ func (mb *mockBackend) authenticate(packetIO *pnet.PacketIO) error {
}

func (mb *mockBackend) verifyPassword(packetIO *pnet.PacketIO, resp *pnet.HandshakeResp) error {
if resp.AuthPlugin != mysql.AuthTiDBSessionToken && mb.switchAuth {
if resp.AuthPlugin != mysql.AuthTiDBSessionToken {
var err error
if err = packetIO.WriteSwitchRequest(mb.authPlugin, mb.salt); err != nil {
return err
Expand Down
15 changes: 12 additions & 3 deletions pkg/proxy/backend/mock_client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -78,16 +78,25 @@ func (mc *mockClient) authenticate(packetIO *pnet.PacketIO) error {
return err
}

resp := pnet.MakeHandshakeResponse(mc.username, mc.dbName, mc.authPlugin, mc.collation, mc.authData, mc.attrs, mc.capability)
resp := &pnet.HandshakeResp{
User: mc.username,
DB: mc.dbName,
AuthPlugin: mc.authPlugin,
Attrs: mc.attrs,
AuthData: mc.authData,
Capability: mc.capability,
Collation: mc.collation,
}
pkt := pnet.MakeHandshakeResponse(resp)
if mc.capability&mysql.ClientSSL > 0 {
if err := packetIO.WritePacket(resp[:32], true); err != nil {
if err := packetIO.WritePacket(pkt[:32], true); err != nil {
return err
}
if err := packetIO.ClientTLSHandshake(mc.tlsConfig); err != nil {
return err
}
}
if err := packetIO.WritePacket(resp, true); err != nil {
if err := packetIO.WritePacket(pkt, true); err != nil {
return err
}
return mc.writePassword(packetIO)
Expand Down
2 changes: 0 additions & 2 deletions pkg/proxy/backend/testsuite_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -190,8 +190,6 @@ func (ts *testSuite) authenticateFirstTime(t *testing.T, c checker) {
// The proxy reconnects to the proxy using preserved client data.
// This must be called after authenticateFirstTime.
func (ts *testSuite) authenticateSecondTime(t *testing.T, c checker) {
// The server won't request switching auth-plugin this time.
ts.mb.backendConfig.switchAuth = false
ts.mb.backendConfig.authSucceed = true
ts.runAndCheck(t, c, nil, ts.mb.authenticate, ts.mp.authenticateSecondTime)
if c == nil {
Expand Down
27 changes: 14 additions & 13 deletions pkg/proxy/net/mysql.go
Original file line number Diff line number Diff line change
Expand Up @@ -143,23 +143,24 @@ func ParseHandshakeResponse(data []byte) *HandshakeResp {
return resp
}

func MakeHandshakeResponse(username, db, authPlugin string, collation uint8, authData, attrs []byte, capability uint32) []byte {
func MakeHandshakeResponse(resp *HandshakeResp) []byte {
// encode length of the auth data
var (
authRespBuf, attrRespBuf [9]byte
authResp, attrResp []byte
)
authResp = DumpLengthEncodedInt(authRespBuf[:0], uint64(len(authData)))
authResp = DumpLengthEncodedInt(authRespBuf[:0], uint64(len(resp.AuthData)))
capability := resp.Capability
if len(authResp) > 1 {
capability |= mysql.ClientPluginAuthLenencClientData
} else {
capability &= ^mysql.ClientPluginAuthLenencClientData
}
if capability&mysql.ClientConnectAtts > 0 {
attrResp = DumpLengthEncodedInt(attrRespBuf[:0], uint64(len(attrs)))
attrResp = DumpLengthEncodedInt(attrRespBuf[:0], uint64(len(resp.Attrs)))
}

length := 4 + 4 + 1 + 23 + len(username) + 1 + len(authResp) + len(authData) + len(db) + 1 + len(authPlugin) + 1 + len(attrResp) + len(attrs)
length := 4 + 4 + 1 + 23 + len(resp.User) + 1 + len(authResp) + len(resp.AuthData) + len(resp.DB) + 1 + len(resp.AuthPlugin) + 1 + len(attrResp) + len(resp.Attrs)
data := make([]byte, length)
pos := 0
// capability [32 bit]
Expand All @@ -168,48 +169,48 @@ func MakeHandshakeResponse(username, db, authPlugin string, collation uint8, aut
// MaxPacketSize [32 bit]
pos += 4
// Charset [1 byte]
data[pos] = collation
data[pos] = resp.Collation
pos++
// Filler [23 bytes] (all 0x00)
pos += 23

// User [null terminated string]
pos += copy(data[pos:], username)
pos += copy(data[pos:], resp.User)
data[pos] = 0x00
pos++

// auth data
if capability&mysql.ClientPluginAuthLenencClientData > 0 {
pos += copy(data[pos:], authResp)
pos += copy(data[pos:], authData)
pos += copy(data[pos:], resp.AuthData)
} else if capability&mysql.ClientSecureConnection > 0 {
data[pos] = byte(len(authData))
data[pos] = byte(len(resp.AuthData))
pos++
pos += copy(data[pos:], authData)
pos += copy(data[pos:], resp.AuthData)
} else {
pos += copy(data[pos:], authData)
pos += copy(data[pos:], resp.AuthData)
data[pos] = 0x00
pos++
}

// db [null terminated string]
if capability&mysql.ClientConnectWithDB > 0 {
pos += copy(data[pos:], db)
pos += copy(data[pos:], resp.DB)
data[pos] = 0x00
pos++
}

// auth_plugin [null terminated string]
if capability&mysql.ClientPluginAuth > 0 {
pos += copy(data[pos:], authPlugin)
pos += copy(data[pos:], resp.AuthPlugin)
data[pos] = 0x00
pos++
}

// attrs
if capability&mysql.ClientConnectAtts > 0 {
pos += copy(data[pos:], attrResp)
pos += copy(data[pos:], attrs)
pos += copy(data[pos:], resp.Attrs)
}
return data[:pos]
}
Expand Down
2 changes: 1 addition & 1 deletion pkg/proxy/net/packetio_mysql.go
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ func (p *PacketIO) ReadSSLRequestOrHandshakeResp() (pkt []byte, isSSL bool, err
return
}

capability := uint32(binary.LittleEndian.Uint32(pkt[:4]))
capability := binary.LittleEndian.Uint32(pkt[:4])
isSSL = capability&mysql.ClientSSL != 0
return
}
Expand Down

0 comments on commit 3432991

Please sign in to comment.