Skip to content

Commit

Permalink
backend: use BackendConnMgr as ConnContext (pingcap#172)
Browse files Browse the repository at this point in the history
  • Loading branch information
xhebox committed Mar 7, 2023
1 parent c4ec5e9 commit 8b4557c
Show file tree
Hide file tree
Showing 17 changed files with 279 additions and 305 deletions.
4 changes: 2 additions & 2 deletions pkg/manager/router/backend_selector.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,10 +26,10 @@ func (bs *BackendSelector) Reset() {
}

func (bs *BackendSelector) Next() string {
if len(bs.cur) > 0 {
bs.cur = bs.routeOnce(bs.excluded)
if bs.cur != "" {
bs.excluded = append(bs.excluded, bs.cur)
}
bs.cur = bs.routeOnce(bs.excluded)
return bs.cur
}

Expand Down
8 changes: 4 additions & 4 deletions pkg/manager/router/router_static.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,18 +17,18 @@ package router
var _ Router = &StaticRouter{}

type StaticRouter struct {
addr []string
cnt int
addrs []string
cnt int
}

func NewStaticRouter(addr []string) *StaticRouter {
return &StaticRouter{addr: addr}
return &StaticRouter{addrs: addr}
}

func (r *StaticRouter) GetBackendSelector() BackendSelector {
return BackendSelector{
routeOnce: func(excluded []string) string {
for _, addr := range r.addr {
for _, addr := range r.addrs {
found := false
for _, e := range excluded {
if e == addr {
Expand Down
153 changes: 59 additions & 94 deletions pkg/proxy/backend/authenticator.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@ import (
"encoding/binary"
"fmt"
"net"
"sync"
"time"

"github.com/pingcap/TiProxy/lib/util/errors"
Expand Down Expand Up @@ -47,19 +46,14 @@ const SupportedServerCapabilities = pnet.ClientLongPassword | pnet.ClientFoundRo

// Authenticator handshakes with the client and the backend.
type Authenticator struct {
backendTLSConfig *tls.Config
ctxmap sync.Map
supportedServerCapabilities pnet.Capability
dbname string // default database name
serverAddr string
clientAddr string
user string
attrs map[string]string
salt []byte
capability uint32 // client capability
collation uint8
proxyProtocol bool
requireBackendTLS bool
dbname string // default database name
user string
attrs map[string]string
salt []byte
capability uint32 // client capability
collation uint8
proxyProtocol bool
requireBackendTLS bool
}

func (auth *Authenticator) String() string {
Expand Down Expand Up @@ -104,17 +98,16 @@ func (auth *Authenticator) verifyBackendCaps(logger *zap.Logger, backendCapabili

type backendIOGetter func(ctx ConnContext, auth *Authenticator, resp *pnet.HandshakeResp, timeout time.Duration) (*pnet.PacketIO, error)

func (auth *Authenticator) handshakeFirstTime(logger *zap.Logger, clientIO *pnet.PacketIO, handshakeHandler HandshakeHandler,
func (auth *Authenticator) handshakeFirstTime(logger *zap.Logger, cctx ConnContext, clientIO *pnet.PacketIO, handshakeHandler HandshakeHandler,
getBackendIO backendIOGetter, frontendTLSConfig, backendTLSConfig *tls.Config) error {
clientIO.ResetSequence()
auth.clientAddr = clientIO.SourceAddr().String()

proxyCapability := auth.supportedServerCapabilities
proxyCapability := handshakeHandler.GetCapability()
if frontendTLSConfig == nil {
proxyCapability ^= pnet.ClientSSL
}

if err := clientIO.WriteInitialHandshake(proxyCapability.Uint32(), auth.salt, mysql.AuthNativePassword); err != nil {
if err := clientIO.WriteInitialHandshake(proxyCapability, auth.salt, mysql.AuthNativePassword); err != nil {
return err
}
pkt, isSSL, err := clientIO.ReadSSLRequestOrHandshakeResp()
Expand Down Expand Up @@ -151,17 +144,17 @@ func (auth *Authenticator) handshakeFirstTime(logger *zap.Logger, clientIO *pnet
}
auth.capability = commonCaps.Uint32()

resp := pnet.ParseHandshakeResponse(pkt)
if err = handshakeHandler.HandleHandshakeResp(auth, resp); err != nil {
clientResp := pnet.ParseHandshakeResponse(pkt)
if err = handshakeHandler.HandleHandshakeResp(cctx, clientResp); err != nil {
return err
}
auth.user = resp.User
auth.dbname = resp.DB
auth.collation = resp.Collation
auth.attrs = resp.Attrs
auth.user = clientResp.User
auth.dbname = clientResp.DB
auth.collation = clientResp.Collation
auth.attrs = clientResp.Attrs

// In case of testing, backendIO is passed manually that we don't want to bother with the routing logic.
backendIO, err := getBackendIO(auth, auth, resp, 5*time.Second)
backendIO, err := getBackendIO(cctx, auth, clientResp, 5*time.Second)
if err != nil {
return err
}
Expand All @@ -173,11 +166,10 @@ func (auth *Authenticator) handshakeFirstTime(logger *zap.Logger, clientIO *pnet
}

// read backend initial handshake
_, backendCapabilityU, err := auth.readInitialHandshake(backendIO)
_, backendCapability, err := auth.readInitialHandshake(backendIO)
if err != nil {
return err
}
backendCapability := pnet.Capability(backendCapabilityU)

if err := auth.verifyBackendCaps(logger, backendCapability); err != nil {
return err
Expand All @@ -195,38 +187,12 @@ func (auth *Authenticator) handshakeFirstTime(logger *zap.Logger, clientIO *pnet
logger.Info("backend does not support capabilities from proxy", zap.Stringer("common", common), zap.Stringer("proxy", proxyCapability^common), zap.Stringer("backend", backendCapability^common))
}

// Send an unknown auth plugin so that the backend will request the auth data again.
resp.AuthPlugin = unknownAuthPlugin
resp.Capability = auth.capability

if backendCapability&pnet.ClientSSL != 0 && backendTLSConfig != nil {
resp.Capability |= mysql.ClientSSL
pkt = pnet.MakeHandshakeResponse(resp)
// write SSL Packet
if err := backendIO.WritePacket(pkt[:32], true); err != nil {
return err
}
auth.backendTLSConfig = backendTLSConfig.Clone()
addr := backendIO.RemoteAddr().String()
if auth.serverAddr != "" {
// NOTE: should use DNS name as much as possible
// Usually certs are signed with domain instead of IP addrs
// And `RemoteAddr()` will return IP addr
addr = auth.serverAddr
}
host, _, err := net.SplitHostPort(addr)
if err == nil {
auth.backendTLSConfig.ServerName = host
}
if err = backendIO.ClientTLSHandshake(auth.backendTLSConfig); err != nil {
return err
}
} else {
pkt = pnet.MakeHandshakeResponse(resp)
}

// forward client handshake resp
if err := backendIO.WritePacket(pkt, true); err != nil {
if err := auth.writeAuthHandshake(
backendIO, backendTLSConfig, backendCapability,
// send an unknown auth plugin so that the backend will request the auth data again.
unknownAuthPlugin, nil, 0,
); err != nil {
return err
}

Expand Down Expand Up @@ -258,7 +224,7 @@ func forwardMsg(srcIO, destIO *pnet.PacketIO) (data []byte, err error) {
return
}

func (auth *Authenticator) handshakeSecondTime(logger *zap.Logger, clientIO, backendIO *pnet.PacketIO, sessionToken string) error {
func (auth *Authenticator) handshakeSecondTime(logger *zap.Logger, clientIO, backendIO *pnet.PacketIO, backendTLSConfig *tls.Config, sessionToken string) error {
if len(sessionToken) == 0 {
return errors.New("session token is empty")
}
Expand All @@ -268,24 +234,26 @@ func (auth *Authenticator) handshakeSecondTime(logger *zap.Logger, clientIO, bac
return err
}

_, serverCapability, err := auth.readInitialHandshake(backendIO)
_, backendCapability, err := auth.readInitialHandshake(backendIO)
if err != nil {
return err
}

if err := auth.verifyBackendCaps(logger, pnet.Capability(serverCapability)); err != nil {
if err := auth.verifyBackendCaps(logger, pnet.Capability(backendCapability)); err != nil {
return err
}

tokenBytes := hack.Slice(sessionToken)
if err = auth.writeAuthHandshake(backendIO, tokenBytes, serverCapability&mysql.ClientSSL != 0); err != nil {
if err = auth.writeAuthHandshake(
backendIO, backendTLSConfig, backendCapability,
mysql.AuthTiDBSessionToken, hack.Slice(sessionToken), mysql.ClientPluginAuth,
); err != nil {
return err
}

return auth.handleSecondAuthResult(backendIO)
}

func (auth *Authenticator) readInitialHandshake(backendIO *pnet.PacketIO) (serverPkt []byte, capability uint32, err error) {
func (auth *Authenticator) readInitialHandshake(backendIO *pnet.PacketIO) (serverPkt []byte, capability pnet.Capability, err error) {
if serverPkt, err = backendIO.ReadPacket(); err != nil {
return
}
Expand All @@ -297,32 +265,49 @@ func (auth *Authenticator) readInitialHandshake(backendIO *pnet.PacketIO) (serve
return
}

func (auth *Authenticator) writeAuthHandshake(backendIO *pnet.PacketIO, authData []byte, tls bool) error {
func (auth *Authenticator) writeAuthHandshake(
backendIO *pnet.PacketIO,
backendTLSConfig *tls.Config,
backendCapability pnet.Capability,
authPlugin string,
authData []byte,
authCap uint32,
) error {
// Always handshake with SSL enabled and enable auth_plugin.
resp := &pnet.HandshakeResp{
User: auth.user,
DB: auth.dbname,
AuthPlugin: mysql.AuthTiDBSessionToken,
Attrs: auth.attrs,
AuthData: authData,
Capability: auth.capability | mysql.ClientSSL | mysql.ClientPluginAuth,
Collation: auth.collation,
AuthData: authData,
Capability: auth.capability | authCap,
AuthPlugin: authPlugin,
}
data := pnet.MakeHandshakeResponse(resp)

if tls && auth.backendTLSConfig != nil {
// write SSL req
if err := backendIO.WritePacket(data[:32], true); err != nil {
var pkt []byte
if backendCapability&pnet.ClientSSL != 0 && backendTLSConfig != nil {
pkt = pnet.MakeHandshakeResponse(resp)
resp.Capability |= mysql.ClientSSL
// write SSL Packet
if err := backendIO.WritePacket(pkt[:32], true); err != nil {
return err
}
// Send TLS / SSL request packet. The server must have supported TLS.
if err := backendIO.ClientTLSHandshake(auth.backendTLSConfig); err != nil {
tcfg := backendTLSConfig.Clone()
addr := backendIO.RemoteAddr().String()
host, _, err := net.SplitHostPort(addr)
if err == nil {
tcfg.ServerName = host
}
if err := backendIO.ClientTLSHandshake(tcfg); err != nil {
return err
}
} else {
pkt = pnet.MakeHandshakeResponse(resp)
}

// write handshake resp
return backendIO.WritePacket(data, true)
return backendIO.WritePacket(pkt, true)
}

func (auth *Authenticator) handleSecondAuthResult(backendIO *pnet.PacketIO) error {
Expand Down Expand Up @@ -353,23 +338,3 @@ func (auth *Authenticator) changeUser(username, db string) {
func (auth *Authenticator) updateCurrentDB(db string) {
auth.dbname = db
}

func (auth *Authenticator) ClientAddr() string {
return auth.clientAddr
}

func (auth *Authenticator) ServerAddr() string {
return auth.serverAddr
}

func (auth *Authenticator) SetValue(key, val any) {
auth.ctxmap.Store(key, val)
}

func (auth *Authenticator) Value(key any) any {
v, ok := auth.ctxmap.Load(key)
if !ok {
return nil
}
return v
}
Loading

0 comments on commit 8b4557c

Please sign in to comment.