diff --git a/pkg/proxy/backend/backend_conn_mgr.go b/pkg/proxy/backend/backend_conn_mgr.go index 15844e86..432b5e0c 100644 --- a/pkg/proxy/backend/backend_conn_mgr.go +++ b/pkg/proxy/backend/backend_conn_mgr.go @@ -29,7 +29,6 @@ import ( gomysql "github.com/go-mysql-org/go-mysql/mysql" "github.com/pingcap/TiProxy/lib/util/errors" "github.com/pingcap/TiProxy/lib/util/waitgroup" - "github.com/pingcap/TiProxy/pkg/manager/namespace" "github.com/pingcap/TiProxy/pkg/manager/router" pnet "github.com/pingcap/TiProxy/pkg/proxy/net" "github.com/pingcap/tidb/parser/mysql" @@ -84,20 +83,18 @@ type BackendConnManager struct { // cancelFunc is used to cancel the signal processing goroutine. cancelFunc context.CancelFunc backendConn *BackendConnection - nsmgr *namespace.NamespaceManager handshakeHandler HandshakeHandler getBackendIO backendIOGetter connectionID uint64 } // NewBackendConnManager creates a BackendConnManager. -func NewBackendConnManager(logger *zap.Logger, nsmgr *namespace.NamespaceManager, handshakeHandler HandshakeHandler, +func NewBackendConnManager(logger *zap.Logger, handshakeHandler HandshakeHandler, connectionID uint64, proxyProtocol, requireBackendTLS bool) *BackendConnManager { mgr := &BackendConnManager{ logger: logger, connectionID: connectionID, cmdProcessor: NewCmdProcessor(), - nsmgr: nsmgr, handshakeHandler: handshakeHandler, authenticator: &Authenticator{ supportedServerCapabilities: handshakeHandler.GetCapability(), @@ -109,7 +106,7 @@ func NewBackendConnManager(logger *zap.Logger, nsmgr *namespace.NamespaceManager redirectResCh: make(chan *redirectResult, 1), } mgr.getBackendIO = func(auth *Authenticator, resp *pnet.HandshakeResp) (*pnet.PacketIO, error) { - ns, err := handshakeHandler.GetNamespace(nsmgr, resp) + ns, err := handshakeHandler.GetNamespace(resp) if err != nil { return nil, err } diff --git a/pkg/proxy/backend/handshake_handler.go b/pkg/proxy/backend/handshake_handler.go index 978e64ad..b32c64bb 100644 --- a/pkg/proxy/backend/handshake_handler.go +++ b/pkg/proxy/backend/handshake_handler.go @@ -25,14 +25,17 @@ var _ HandshakeHandler = (*DefaultHandshakeHandler)(nil) type HandshakeHandler interface { HandleHandshakeResp(resp *pnet.HandshakeResp, sourceAddr string) error GetCapability() pnet.Capability - GetNamespace(nsMgr *namespace.NamespaceManager, resp *pnet.HandshakeResp) (*namespace.Namespace, error) + GetNamespace(resp *pnet.HandshakeResp) (*namespace.Namespace, error) } type DefaultHandshakeHandler struct { + nsManager *namespace.NamespaceManager } -func NewDefaultHandshakeHandler() *DefaultHandshakeHandler { - return &DefaultHandshakeHandler{} +func NewDefaultHandshakeHandler(nsManager *namespace.NamespaceManager) *DefaultHandshakeHandler { + return &DefaultHandshakeHandler{ + nsManager: nsManager, + } } func (handler *DefaultHandshakeHandler) HandleHandshakeResp(*pnet.HandshakeResp, string) error { @@ -43,10 +46,10 @@ func (handler *DefaultHandshakeHandler) GetCapability() pnet.Capability { return SupportedServerCapabilities } -func (handler *DefaultHandshakeHandler) GetNamespace(nsMgr *namespace.NamespaceManager, resp *pnet.HandshakeResp) (*namespace.Namespace, error) { - ns, ok := nsMgr.GetNamespaceByUser(resp.User) +func (handler *DefaultHandshakeHandler) GetNamespace(resp *pnet.HandshakeResp) (*namespace.Namespace, error) { + ns, ok := handler.nsManager.GetNamespaceByUser(resp.User) if !ok { - ns, ok = nsMgr.GetNamespace("default") + ns, ok = handler.nsManager.GetNamespace("default") } if !ok { return nil, errors.New("failed to find a namespace") diff --git a/pkg/proxy/backend/mock_proxy_test.go b/pkg/proxy/backend/mock_proxy_test.go index 3942b84a..6103ee97 100644 --- a/pkg/proxy/backend/mock_proxy_test.go +++ b/pkg/proxy/backend/mock_proxy_test.go @@ -36,7 +36,7 @@ type proxyConfig struct { func newProxyConfig() *proxyConfig { return &proxyConfig{ - handler: NewDefaultHandshakeHandler(), + handler: NewDefaultHandshakeHandler(nil), capability: defaultTestBackendCapability, sessionToken: mockToken, } @@ -57,7 +57,7 @@ func newMockProxy(t *testing.T, cfg *proxyConfig) *mockProxy { mp := &mockProxy{ proxyConfig: cfg, logger: logger.CreateLoggerForTest(t).Named("mockProxy"), - BackendConnManager: NewBackendConnManager(logger.CreateLoggerForTest(t), nil, cfg.handler, 0, false, false), + BackendConnManager: NewBackendConnManager(logger.CreateLoggerForTest(t), cfg.handler, 0, false, false), } mp.cmdProcessor.capability = cfg.capability return mp @@ -107,7 +107,7 @@ type CustomHandshakeHandler struct { outAttrs map[string]string } -func (handler *CustomHandshakeHandler) GetNamespace(nsMgr *namespace.NamespaceManager, resp *pnet.HandshakeResp) (*namespace.Namespace, error) { +func (handler *CustomHandshakeHandler) GetNamespace(resp *pnet.HandshakeResp) (*namespace.Namespace, error) { return &namespace.Namespace{}, nil } diff --git a/pkg/proxy/client/client_conn.go b/pkg/proxy/client/client_conn.go index d50fe40c..81b9b871 100644 --- a/pkg/proxy/client/client_conn.go +++ b/pkg/proxy/client/client_conn.go @@ -21,7 +21,6 @@ import ( "net" "github.com/pingcap/TiProxy/lib/util/errors" - "github.com/pingcap/TiProxy/pkg/manager/namespace" "github.com/pingcap/TiProxy/pkg/proxy/backend" pnet "github.com/pingcap/TiProxy/pkg/proxy/net" "github.com/pingcap/tidb/parser/mysql" @@ -41,8 +40,8 @@ type ClientConnection struct { } func NewClientConnection(logger *zap.Logger, conn net.Conn, frontendTLSConfig *tls.Config, backendTLSConfig *tls.Config, - nsmgr *namespace.NamespaceManager, connID uint64, proxyProtocol, requireBackendTLS bool) *ClientConnection { - bemgr := backend.NewBackendConnManager(logger.Named("be"), nsmgr, backend.NewDefaultHandshakeHandler(), connID, proxyProtocol, requireBackendTLS) + hsHandler backend.HandshakeHandler, connID uint64, proxyProtocol, requireBackendTLS bool) *ClientConnection { + bemgr := backend.NewBackendConnManager(logger.Named("be"), hsHandler, connID, proxyProtocol, requireBackendTLS) opts := make([]pnet.PacketIOption, 0, 2) opts = append(opts, pnet.WithWrapError(ErrClientConn)) if proxyProtocol { diff --git a/pkg/proxy/proxy.go b/pkg/proxy/proxy.go index 800a0d62..9109770f 100644 --- a/pkg/proxy/proxy.go +++ b/pkg/proxy/proxy.go @@ -23,8 +23,8 @@ import ( "github.com/pingcap/TiProxy/lib/util/errors" "github.com/pingcap/TiProxy/lib/util/waitgroup" "github.com/pingcap/TiProxy/pkg/manager/cert" - mgrns "github.com/pingcap/TiProxy/pkg/manager/namespace" "github.com/pingcap/TiProxy/pkg/metrics" + "github.com/pingcap/TiProxy/pkg/proxy/backend" "github.com/pingcap/TiProxy/pkg/proxy/client" pnet "github.com/pingcap/TiProxy/pkg/proxy/net" "go.uber.org/zap" @@ -43,7 +43,7 @@ type SQLServer struct { listener net.Listener logger *zap.Logger certMgr *cert.CertManager - nsmgr *mgrns.NamespaceManager + hsHandler backend.HandshakeHandler requireBackendTLS bool wg waitgroup.WaitGroup @@ -51,13 +51,13 @@ type SQLServer struct { } // NewSQLServer creates a new SQLServer. -func NewSQLServer(logger *zap.Logger, cfg config.ProxyServer, certMgr *cert.CertManager, nsmgr *mgrns.NamespaceManager) (*SQLServer, error) { +func NewSQLServer(logger *zap.Logger, cfg config.ProxyServer, certMgr *cert.CertManager, hsHandler backend.HandshakeHandler) (*SQLServer, error) { var err error s := &SQLServer{ logger: logger, certMgr: certMgr, - nsmgr: nsmgr, + hsHandler: hsHandler, requireBackendTLS: cfg.RequireBackendTLS, mu: serverState{ connID: 0, @@ -124,7 +124,7 @@ func (s *SQLServer) onConn(ctx context.Context, conn net.Conn) { connID := s.mu.connID s.mu.connID++ logger := s.logger.With(zap.Uint64("connID", connID), zap.String("remoteAddr", conn.RemoteAddr().String())) - clientConn := client.NewClientConnection(logger.Named("conn"), conn, s.certMgr.ServerTLS(), s.certMgr.SQLTLS(), s.nsmgr, connID, s.mu.proxyProtocol, s.requireBackendTLS) + clientConn := client.NewClientConnection(logger.Named("conn"), conn, s.certMgr.ServerTLS(), s.certMgr.SQLTLS(), s.hsHandler, connID, s.mu.proxyProtocol, s.requireBackendTLS) s.mu.clients[connID] = clientConn s.mu.Unlock() diff --git a/pkg/server/server.go b/pkg/server/server.go index f9d006ee..24c6b16e 100644 --- a/pkg/server/server.go +++ b/pkg/server/server.go @@ -37,6 +37,7 @@ import ( "github.com/pingcap/TiProxy/pkg/manager/router" "github.com/pingcap/TiProxy/pkg/metrics" "github.com/pingcap/TiProxy/pkg/proxy" + "github.com/pingcap/TiProxy/pkg/proxy/backend" "github.com/pingcap/TiProxy/pkg/sctx" "github.com/pingcap/TiProxy/pkg/server/api" clientv3 "go.etcd.io/etcd/client/v3" @@ -194,7 +195,8 @@ func NewServer(ctx context.Context, sctx *sctx.Context) (srv *Server, err error) // setup proxy server { - srv.Proxy, err = proxy.NewSQLServer(lg.Named("proxy"), cfg.Proxy, srv.CertManager, srv.NamespaceManager) + hsHandler := backend.NewDefaultHandshakeHandler(srv.NamespaceManager) + srv.Proxy, err = proxy.NewSQLServer(lg.Named("proxy"), cfg.Proxy, srv.CertManager, hsHandler) if err != nil { err = errors.WithStack(err) return