Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

router: remove connMap #238

Merged
merged 3 commits into from
Mar 7, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions pkg/manager/namespace/manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,8 +48,8 @@ func (mgr *NamespaceManager) buildNamespace(cfg *config.Namespace) (*Namespace,
} else {
fetcher = router.NewStaticFetcher(cfg.Backend.Instances)
}
rt, err := router.NewScoreBasedRouter(logger.Named("router"), mgr.httpCli, fetcher, config.NewDefaultHealthCheckConfig())
if err != nil {
rt := router.NewScoreBasedRouter(logger.Named("router"))
if err := rt.Init(mgr.httpCli, fetcher, config.NewDefaultHealthCheckConfig()); err != nil {
return nil, errors.Errorf("build router error: %w", err)
}
return &Namespace{
Expand Down
4 changes: 2 additions & 2 deletions pkg/manager/router/router.go
Original file line number Diff line number Diff line change
Expand Up @@ -73,10 +73,11 @@ const (
// RedirectableConn indicates a redirect-able connection.
type RedirectableConn interface {
SetEventReceiver(receiver ConnEventReceiver)
SetValue(key, val any)
Value(key any) any
Redirect(addr string)
GetRedirectingAddr() string
NotifyBackendStatus(status BackendStatus)
ConnectionID() uint64
}

// backendWrapper contains the connections on the backend.
Expand All @@ -86,7 +87,6 @@ type backendWrapper struct {
// A list of *connWrapper and is ordered by the connecting or redirecting time.
// connList and connMap include moving out connections but not moving in connections.
connList *glist.List[*connWrapper]
connMap map[uint64]*glist.Element[*connWrapper]
}

// score calculates the score of the backend. Larger score indicates higher load.
Expand Down
61 changes: 29 additions & 32 deletions pkg/manager/router/router_score.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,10 @@ import (
"go.uber.org/zap"
)

const (
_routerKey = "__tiproxy_router"
)

var _ Router = &ScoreBasedRouter{}

// ScoreBasedRouter is an implementation of Router interface.
Expand All @@ -43,23 +47,26 @@ type ScoreBasedRouter struct {
}

// NewScoreBasedRouter creates a ScoreBasedRouter.
func NewScoreBasedRouter(logger *zap.Logger, httpCli *http.Client, fetcher BackendFetcher, cfg *config.HealthCheck) (*ScoreBasedRouter, error) {
router := &ScoreBasedRouter{
func NewScoreBasedRouter(logger *zap.Logger) *ScoreBasedRouter {
return &ScoreBasedRouter{
logger: logger,
backends: glist.New[*backendWrapper](),
}
}

func (r *ScoreBasedRouter) Init(httpCli *http.Client, fetcher BackendFetcher, cfg *config.HealthCheck) error {
cfg.Check()
observer, err := StartBackendObserver(logger.Named("observer"), router, httpCli, cfg, fetcher)
observer, err := StartBackendObserver(r.logger.Named("observer"), r, httpCli, cfg, fetcher)
if err != nil {
return nil, err
return err
}
router.observer = observer
r.observer = observer
childCtx, cancelFunc := context.WithCancel(context.Background())
router.cancelFunc = cancelFunc
router.wg.Run(func() {
router.rebalanceLoop(childCtx)
r.cancelFunc = cancelFunc
r.wg.Run(func() {
r.rebalanceLoop(childCtx)
})
return router, nil
return nil
}

// GetBackendSelector implements Router.GetBackendSelector interface.
Expand All @@ -70,6 +77,14 @@ func (router *ScoreBasedRouter) GetBackendSelector() BackendSelector {
}
}

func (router *ScoreBasedRouter) getConnWrapper(conn RedirectableConn) *glist.Element[*connWrapper] {
return conn.Value(_routerKey).(*glist.Element[*connWrapper])
}

func (router *ScoreBasedRouter) setConnWrapper(conn RedirectableConn, ce *glist.Element[*connWrapper]) {
conn.SetValue(_routerKey, ce)
}

func (router *ScoreBasedRouter) routeOnce(excluded []string) (string, error) {
router.Lock()
defer router.Unlock()
Expand Down Expand Up @@ -122,9 +137,7 @@ func (router *ScoreBasedRouter) addNewConn(addr string, conn RedirectableConn) e

func (router *ScoreBasedRouter) removeConn(be *glist.Element[*backendWrapper], ce *glist.Element[*connWrapper]) {
backend := be.Value
conn := ce.Value
backend.connList.Remove(ce)
delete(backend.connMap, conn.ConnectionID())
if !router.removeBackendIfEmpty(be) {
router.adjustBackendList(be)
}
Expand All @@ -133,7 +146,7 @@ func (router *ScoreBasedRouter) removeConn(be *glist.Element[*backendWrapper], c
func (router *ScoreBasedRouter) addConn(be *glist.Element[*backendWrapper], conn *connWrapper) {
backend := be.Value
ce := backend.connList.PushBack(conn)
backend.connMap[conn.ConnectionID()] = ce
router.setConnWrapper(conn, ce)
router.adjustBackendList(be)
}

Expand Down Expand Up @@ -213,12 +226,8 @@ func (router *ScoreBasedRouter) OnRedirectSucceed(from, to string, conn Redirect
return errors.WithStack(errors.Errorf("backend %s is not found in the router", to))
}
toBackend := be.Value
e, ok := toBackend.connMap[conn.ConnectionID()]
if !ok {
return errors.WithStack(errors.Errorf("connection %d is not found on the backend %s", conn.ConnectionID(), to))
}
conn.NotifyBackendStatus(toBackend.status)
connWrapper := e.Value
connWrapper := router.getConnWrapper(conn).Value
connWrapper.phase = phaseRedirectEnd
addMigrateMetrics(from, to, true, connWrapper.lastRedirect)
subBackendConnMetrics(from)
Expand All @@ -234,12 +243,7 @@ func (router *ScoreBasedRouter) OnRedirectFail(from, to string, conn Redirectabl
if be == nil {
return errors.WithStack(errors.Errorf("backend %s is not found in the router", to))
}
toBackend := be.Value
ce, ok := toBackend.connMap[conn.ConnectionID()]
if !ok {
return errors.WithStack(errors.Errorf("connection %d is not found on the backend %s", conn.ConnectionID(), to))
}
router.removeConn(be, ce)
router.removeConn(be, router.getConnWrapper(conn))

be = router.lookupBackend(from, true)
// The backend may have been removed because it's empty. Add it back.
Expand All @@ -248,11 +252,10 @@ func (router *ScoreBasedRouter) OnRedirectFail(from, to string, conn Redirectabl
status: StatusCannotConnect,
addr: from,
connList: glist.New[*connWrapper](),
connMap: make(map[uint64]*glist.Element[*connWrapper]),
})
}
conn.NotifyBackendStatus(be.Value.status)
connWrapper := ce.Value
connWrapper := router.getConnWrapper(conn).Value
connWrapper.phase = phaseRedirectFail
addMigrateMetrics(from, to, false, connWrapper.lastRedirect)
router.addConn(be, connWrapper)
Expand All @@ -273,12 +276,7 @@ func (router *ScoreBasedRouter) OnConnClosed(addr string, conn RedirectableConn)
if be == nil {
return errors.WithStack(errors.Errorf("backend %s is not found in the router", addr))
}
backend := be.Value
ce, ok := backend.connMap[conn.ConnectionID()]
if !ok {
return errors.WithStack(errors.Errorf("connection %d is not found on the backend %s", conn.ConnectionID(), addr))
}
router.removeConn(be, ce)
router.removeConn(be, router.getConnWrapper(conn))
subBackendConnMetrics(addr)
return nil
}
Expand All @@ -297,7 +295,6 @@ func (router *ScoreBasedRouter) OnBackendChanged(backends map[string]BackendStat
status: status,
addr: addr,
connList: glist.New[*connWrapper](),
connMap: make(map[uint64]*glist.Element[*connWrapper]),
})
} else if be != nil {
backend := be.Value
Expand Down
60 changes: 34 additions & 26 deletions pkg/manager/router/router_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@ import (
"testing"
"time"

glist "github.com/bahlo/generic-list-go"
"github.com/pingcap/TiProxy/lib/config"
"github.com/pingcap/TiProxy/lib/util/errors"
"github.com/pingcap/TiProxy/lib/util/logger"
Expand All @@ -35,18 +34,40 @@ import (
type mockRedirectableConn struct {
sync.Mutex
t *testing.T
kv map[any]any
connID uint64
from, to string
status BackendStatus
receiver ConnEventReceiver
}

func newMockRedirectableConn(t *testing.T, id uint64) *mockRedirectableConn {
return &mockRedirectableConn{
t: t,
connID: id,
kv: make(map[any]any),
}
}

func (conn *mockRedirectableConn) SetEventReceiver(receiver ConnEventReceiver) {
conn.Lock()
conn.receiver = receiver
conn.Unlock()
}

func (conn *mockRedirectableConn) SetValue(k, v any) {
conn.Lock()
conn.kv[k] = v
conn.Unlock()
}

func (conn *mockRedirectableConn) Value(k any) any {
conn.Lock()
v := conn.kv[k]
conn.Unlock()
return v
}

func (conn *mockRedirectableConn) Redirect(addr string) {
conn.Lock()
require.Len(conn.t, conn.to, 0)
Expand Down Expand Up @@ -100,10 +121,7 @@ type routerTester struct {
}

func newRouterTester(t *testing.T) *routerTester {
router := &ScoreBasedRouter{
logger: logger.CreateLoggerForTest(t),
backends: glist.New[*backendWrapper](),
}
router := NewScoreBasedRouter(logger.CreateLoggerForTest(t))
t.Cleanup(router.Close)
return &routerTester{
t: t,
Expand All @@ -114,10 +132,7 @@ func newRouterTester(t *testing.T) *routerTester {

func (tester *routerTester) createConn() *mockRedirectableConn {
tester.connID++
return &mockRedirectableConn{
t: tester.t,
connID: tester.connID,
}
return newMockRedirectableConn(tester.t, tester.connID)
}

func (tester *routerTester) addBackends(num int) {
Expand Down Expand Up @@ -306,7 +321,7 @@ func (tester *routerTester) checkBackendConnMetrics() {

func (tester *routerTester) clear() {
tester.conns = make(map[uint64]*mockRedirectableConn)
tester.router.backends = glist.New[*backendWrapper]()
tester.router.backends.Init()
}

// Test that the backends are always ordered by scores.
Expand Down Expand Up @@ -506,7 +521,7 @@ func TestRebalanceCornerCase(t *testing.T) {
tester.checkRedirectingNum(10)
tester.checkBackendNum(1)
backend := tester.getBackendByIndex(0)
require.Len(t, backend.connMap, 10)
require.Equal(t, 10, backend.connList.Len())
},
func() {
// Connections won't be redirected again before redirection finishes.
Expand All @@ -521,7 +536,7 @@ func TestRebalanceCornerCase(t *testing.T) {
tester.addBackends(1)
tester.rebalance(10)
tester.checkRedirectingNum(10)
require.Len(t, backend.connMap, 10)
require.Equal(t, 10, backend.connList.Len())
},
func() {
// After redirection fails, the connections are moved back to the unhealthy backends.
Expand Down Expand Up @@ -587,7 +602,8 @@ func TestConcurrency(t *testing.T) {
client := createEtcdClient(t, etcd)
healthCheckConfig := newHealthCheckConfigForTest()
fetcher := NewPDFetcher(client, logger.CreateLoggerForTest(t), healthCheckConfig)
router, err := NewScoreBasedRouter(logger.CreateLoggerForTest(t), nil, fetcher, healthCheckConfig)
router := NewScoreBasedRouter(logger.CreateLoggerForTest(t))
err := router.Init(nil, fetcher, healthCheckConfig)
require.NoError(t, err)

var wg waitgroup.WaitGroup
Expand Down Expand Up @@ -637,10 +653,7 @@ func TestConcurrency(t *testing.T) {

if conn == nil {
// not connected, connect
conn = &mockRedirectableConn{
t: t,
connID: connID,
}
conn = newMockRedirectableConn(t, connID)
selector := router.GetBackendSelector()
addr, err := selector.Next()
require.NoError(t, err)
Expand Down Expand Up @@ -698,10 +711,7 @@ func TestRefresh(t *testing.T) {
})
// Create a router with a very long health check interval.
lg := logger.CreateLoggerForTest(t)
rt := &ScoreBasedRouter{
logger: lg,
backends: glist.New[*backendWrapper](),
}
rt := NewScoreBasedRouter(lg)
cfg := config.NewDefaultHealthCheckConfig()
cfg.Interval = time.Minute
observer, err := StartBackendObserver(lg, rt, nil, cfg, fetcher)
Expand Down Expand Up @@ -740,10 +750,7 @@ func TestObserveError(t *testing.T) {
})
// Create a router with a very short health check interval.
lg := logger.CreateLoggerForTest(t)
rt := &ScoreBasedRouter{
logger: lg,
backends: glist.New[*backendWrapper](),
}
rt := NewScoreBasedRouter(lg)
observer, err := StartBackendObserver(lg, rt, nil, newHealthCheckConfigForTest(), fetcher)
require.NoError(t, err)
rt.Lock()
Expand Down Expand Up @@ -798,7 +805,8 @@ func TestDisableHealthCheck(t *testing.T) {
})
// Create a router with a very short health check interval.
lg := logger.CreateLoggerForTest(t)
rt, err := NewScoreBasedRouter(lg, nil, fetcher, &config.HealthCheck{Enable: false})
rt := NewScoreBasedRouter(lg)
err := rt.Init(nil, fetcher, &config.HealthCheck{Enable: false})
require.NoError(t, err)
defer rt.Close()
// No backends and no error.
Expand Down