disttask: skip scheduler take slots for some states (#51022)
ref #49008
ywqzzy authored Mar 11, 2024
1 parent f495cc5 commit 8b02143
Showing 14 changed files with 440 additions and 192 deletions.
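In short, the manager now starts schedulers for reverting/cancelling/pausing tasks without reserving execution slots, and a scheduler started that way exits as soon as the task becomes runnable again, leaving it to a slot-holding scheduler. The sketch below distills the state-based decision; the helper name shouldAllocateSlots is hypothetical (the commit inlines this switch in startSchedulers), and it assumes the proto.TaskState constants used in the diff.

```go
package scheduler

import "github.com/pingcap/tidb/pkg/disttask/framework/proto"

// shouldAllocateSlots is a hypothetical distillation of the switch added to
// startSchedulers: only tasks that are running or about to run need slots
// reserved; schedulers for the remaining states (reverting/cancelling/pausing)
// only drive the task out of that state and skip slot reservation.
func shouldAllocateSlots(state proto.TaskState) bool {
	switch state {
	case proto.TaskStatePending, proto.TaskStateRunning, proto.TaskStateResuming:
		return true
	default:
		return false
	}
}
```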
6 changes: 5 additions & 1 deletion pkg/disttask/framework/scheduler/BUILD.bazel
@@ -11,13 +11,15 @@ go_library(
"scheduler_manager.go",
"slots.go",
"state_transform.go",
"testutil.go",
],
importpath = "github.com/pingcap/tidb/pkg/disttask/framework/scheduler",
visibility = ["//visibility:public"],
deps = [
"//br/pkg/lightning/log",
"//pkg/disttask/framework/handle",
"//pkg/disttask/framework/proto",
"//pkg/disttask/framework/scheduler/mock",
"//pkg/disttask/framework/storage",
"//pkg/domain/infosync",
"//pkg/kv",
@@ -34,6 +36,7 @@ go_library(
"@com_github_pingcap_failpoint//:failpoint",
"@com_github_pingcap_log//:log",
"@com_github_prometheus_client_golang//prometheus",
"@org_uber_go_mock//gomock",
"@org_uber_go_zap//:zap",
],
)
@@ -45,6 +48,7 @@ go_test(
"balancer_test.go",
"main_test.go",
"nodes_test.go",
"scheduler_manager_nokit_test.go",
"scheduler_manager_test.go",
"scheduler_nokit_test.go",
"scheduler_test.go",
@@ -53,7 +57,7 @@
embed = [":scheduler"],
flaky = True,
race = "off",
shard_count = 31,
shard_count = 33,
deps = [
"//pkg/config",
"//pkg/disttask/framework/mock",
9 changes: 5 additions & 4 deletions pkg/disttask/framework/scheduler/interface.go
@@ -136,10 +136,11 @@ type Extension interface {

// Param is used to pass parameters when creating scheduler.
type Param struct {
taskMgr TaskManager
nodeMgr *NodeManager
slotMgr *SlotManager
serverID string
taskMgr TaskManager
nodeMgr *NodeManager
slotMgr *SlotManager
serverID string
allocatedSlots bool
}

// schedulerFactoryFn is used to create a scheduler.
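For readability, here is the new Param layout again with a doc comment on the added field; the comment is an editorial reading of how scheduler.go and scheduler_manager.go use it below, not text from the commit.

```go
// Param is used to pass parameters when creating scheduler.
type Param struct {
	taskMgr  TaskManager
	nodeMgr  *NodeManager
	slotMgr  *SlotManager
	serverID string
	// allocatedSlots is false when the scheduler was started without reserving
	// execution slots (reverting/cancelling/pausing tasks); such a scheduler
	// must exit once the task becomes runnable again, so that a slot-holding
	// scheduler can take over.
	allocatedSlots bool
}
```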
29 changes: 28 additions & 1 deletion pkg/disttask/framework/scheduler/scheduler.go
@@ -95,7 +95,7 @@ var MockOwnerChange func()

// NewBaseScheduler creates a new BaseScheduler.
func NewBaseScheduler(ctx context.Context, task *proto.Task, param Param) *BaseScheduler {
logger := log.L().With(zap.Int64("task-id", task.ID), zap.Stringer("task-type", task.Type))
logger := log.L().With(zap.Int64("task-id", task.ID), zap.Stringer("task-type", task.Type), zap.Bool("allocated-slots", param.allocatedSlots))
if intest.InTest {
logger = logger.With(zap.String("server-id", param.serverID))
}
@@ -179,6 +179,10 @@ func (s *BaseScheduler) scheduleTask() {
continue
}
task := *s.GetTask()
// TODO: refine failpoints below.
failpoint.Inject("exitScheduler", func() {
failpoint.Return()
})
failpoint.Inject("cancelTaskAfterRefreshTask", func(val failpoint.Value) {
if val.(bool) && task.State == proto.TaskStateRunning {
err := s.taskMgr.CancelTask(s.ctx, task.ID)
@@ -222,12 +226,35 @@
return
}
case proto.TaskStateResuming:
// Case with 2 nodes.
// Here is the timeline:
// 1. the task is in pausing state.
// 2. node1 and node2 both start schedulers for the pausing task without allocatedSlots.
// 3. node1's scheduler transfers the task from pausing to paused state.
// 4. the task is resumed.
// 5. node2's scheduler calls refreshTask and sees the task in resuming state.
if !s.allocatedSlots {
s.logger.Info("scheduler exit since not allocated slots", zap.Stringer("state", task.State))
return
}
err = s.onResuming()
case proto.TaskStateReverting:
err = s.onReverting()
case proto.TaskStatePending:
err = s.onPending()
case proto.TaskStateRunning:
// Case with 2 nodes.
// Here is the timeline:
// 1. the task is in pausing state.
// 2. node1 and node2 both start schedulers for the pausing task without allocatedSlots.
// 3. node1's scheduler transfers the task from pausing to paused state.
// 4. the task is resumed.
// 5. node1 starts another scheduler and transfers the task from resuming to running state.
// 6. node2's scheduler calls refreshTask and sees the task in running state.
if !s.allocatedSlots {
s.logger.Info("scheduler exit since not allocated slots", zap.Stringer("state", task.State))
return
}
err = s.onRunning()
case proto.TaskStateSucceed, proto.TaskStateReverted, proto.TaskStateFailed:
s.onFinished()
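The two guards added to the resuming and running branches above are identical; a hypothetical shared helper could look like the sketch below. It assumes BaseScheduler embeds Param (so s.allocatedSlots and s.logger are available, as the diff itself relies on) and is not part of the commit.

```go
// exitBecauseNotAllocated reports whether this scheduler must bow out: it was
// started without reserving slots, but the task has since moved into a state
// (resuming/running) that should be driven by a slot-holding scheduler.
func (s *BaseScheduler) exitBecauseNotAllocated(state proto.TaskState) bool {
	if s.allocatedSlots {
		return false
	}
	s.logger.Info("scheduler exit since not allocated slots", zap.Stringer("state", state))
	return true
}
```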
129 changes: 81 additions & 48 deletions pkg/disttask/framework/scheduler/scheduler_manager.go
@@ -216,52 +216,80 @@ func (sm *Manager) scheduleTaskLoop() {
continue
}

tasks, err := sm.taskMgr.GetTopUnfinishedTasks(sm.ctx)
schedulableTasks, err := sm.getSchedulableTasks()
if err != nil {
sm.logger.Warn("get unfinished tasks failed", zap.Error(err))
continue
}

schedulableTasks := make([]*proto.TaskBase, 0, len(tasks))
for _, task := range tasks {
if sm.hasScheduler(task.ID) {
continue
}
// we check it before starting the scheduler, so no need to check it again;
// see startScheduler.
// this should not happen normally, unless the user modifies the system
// table directly.
if getSchedulerFactory(task.Type) == nil {
sm.logger.Warn("unknown task type", zap.Int64("task-id", task.ID),
zap.Stringer("task-type", task.Type))
sm.failTask(task.ID, task.State, errors.New("unknown task type"))
continue
}
schedulableTasks = append(schedulableTasks, task)
}
if len(schedulableTasks) == 0 {
err = sm.startSchedulers(schedulableTasks)
if err != nil {
continue
}
}
}

func (sm *Manager) getSchedulableTasks() ([]*proto.TaskBase, error) {
tasks, err := sm.taskMgr.GetTopUnfinishedTasks(sm.ctx)
if err != nil {
sm.logger.Warn("get unfinished tasks failed", zap.Error(err))
return nil, err
}

if err = sm.slotMgr.update(sm.ctx, sm.nodeMgr, sm.taskMgr); err != nil {
sm.logger.Warn("update used slot failed", zap.Error(err))
schedulableTasks := make([]*proto.TaskBase, 0, len(tasks))
for _, task := range tasks {
if sm.hasScheduler(task.ID) {
continue
}
for _, task := range schedulableTasks {
taskCnt = sm.getSchedulerCount()
if taskCnt >= proto.MaxConcurrentTask {
break
}
reservedExecID, ok := sm.slotMgr.canReserve(task)
// we check it before starting the scheduler, so no need to check it again;
// see startScheduler.
// this should not happen normally, unless the user modifies the system
// table directly.
if getSchedulerFactory(task.Type) == nil {
sm.logger.Warn("unknown task type", zap.Int64("task-id", task.ID),
zap.Stringer("task-type", task.Type))
sm.failTask(task.ID, task.State, errors.New("unknown task type"))
continue
}
schedulableTasks = append(schedulableTasks, task)
}
return schedulableTasks, nil
}

func (sm *Manager) startSchedulers(schedulableTasks []*proto.TaskBase) error {
if len(schedulableTasks) == 0 {
return nil
}
if err := sm.slotMgr.update(sm.ctx, sm.nodeMgr, sm.taskMgr); err != nil {
sm.logger.Warn("update used slot failed", zap.Error(err))
return err
}
for _, task := range schedulableTasks {
taskCnt := sm.getSchedulerCount()
if taskCnt >= proto.MaxConcurrentTask {
break
}
var reservedExecID string
allocateSlots := true
var ok bool
switch task.State {
case proto.TaskStatePending, proto.TaskStateRunning, proto.TaskStateResuming:
reservedExecID, ok = sm.slotMgr.canReserve(task)
if !ok {
// a task of lower rank might still be schedulable.
continue
}
metrics.DistTaskGauge.WithLabelValues(task.Type.String(), metrics.SchedulingStatus).Inc()
metrics.UpdateMetricsForDispatchTask(task.ID, task.Type)
sm.startScheduler(task, reservedExecID)
// reverting/cancelling/pausing
default:
allocateSlots = false
sm.logger.Info("start scheduler without allocating slots",
zap.Int64("task-id", task.ID), zap.Stringer("state", task.State))
}

metrics.DistTaskGauge.WithLabelValues(task.Type.String(), metrics.SchedulingStatus).Inc()
metrics.UpdateMetricsForScheduleTask(task.ID, task.Type)
sm.startScheduler(task, allocateSlots, reservedExecID)
}
return nil
}

func (sm *Manager) failTask(id int64, currState proto.TaskState, err error) {
@@ -300,7 +328,7 @@ func (sm *Manager) gcSubtaskHistoryTableLoop() {
}
}

func (sm *Manager) startScheduler(basicTask *proto.TaskBase, reservedExecID string) {
func (sm *Manager) startScheduler(basicTask *proto.TaskBase, allocateSlots bool, reservedExecID string) {
task, err := sm.taskMgr.GetTaskByID(sm.ctx, basicTask.ID)
if err != nil {
sm.logger.Error("get task failed", zap.Int64("task-id", basicTask.ID), zap.Error(err))
@@ -309,24 +337,29 @@ func (sm *Manager) startScheduler(basicTask *proto.TaskBase, reservedExecID stri

schedulerFactory := getSchedulerFactory(task.Type)
scheduler := schedulerFactory(sm.ctx, task, Param{
taskMgr: sm.taskMgr,
nodeMgr: sm.nodeMgr,
slotMgr: sm.slotMgr,
serverID: sm.serverID,
taskMgr: sm.taskMgr,
nodeMgr: sm.nodeMgr,
slotMgr: sm.slotMgr,
serverID: sm.serverID,
allocatedSlots: allocateSlots,
})
if err = scheduler.Init(); err != nil {
sm.logger.Error("init scheduler failed", zap.Error(err))
sm.failTask(task.ID, task.State, err)
return
}
sm.addScheduler(task.ID, scheduler)
sm.slotMgr.reserve(basicTask, reservedExecID)
if allocateSlots {
sm.slotMgr.reserve(basicTask, reservedExecID)
}
sm.logger.Info("task scheduler started", zap.Int64("task-id", task.ID))
sm.schedulerWG.RunWithLog(func() {
defer func() {
scheduler.Close()
sm.delScheduler(task.ID)
sm.slotMgr.unReserve(basicTask, reservedExecID)
if allocateSlots {
sm.slotMgr.unReserve(basicTask, reservedExecID)
}
handle.NotifyTaskChange()
sm.logger.Info("task scheduler exist", zap.Int64("task-id", task.ID))
}()
@@ -416,16 +449,6 @@ func (sm *Manager) cleanupFinishedTasks(tasks []*proto.Task) error {
return sm.taskMgr.TransferTasks2History(sm.ctx, cleanedTasks)
}

// MockScheduler mock one scheduler for one task, only used for tests.
func (sm *Manager) MockScheduler(task *proto.Task) *BaseScheduler {
return NewBaseScheduler(sm.ctx, task, Param{
taskMgr: sm.taskMgr,
nodeMgr: sm.nodeMgr,
slotMgr: sm.slotMgr,
serverID: sm.serverID,
})
}

func (sm *Manager) collectLoop() {
sm.logger.Info("collect loop start")
ticker := time.NewTicker(defaultCollectMetricsInterval)
@@ -450,3 +473,13 @@ func (sm *Manager) collect() {

subtaskCollector.subtaskInfo.Store(&subtasks)
}

// MockScheduler mock one scheduler for one task, only used for tests.
func (sm *Manager) MockScheduler(task *proto.Task) *BaseScheduler {
return NewBaseScheduler(sm.ctx, task, Param{
taskMgr: sm.taskMgr,
nodeMgr: sm.nodeMgr,
slotMgr: sm.slotMgr,
serverID: sm.serverID,
})
}
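Note that MockScheduler leaves the new allocatedSlots field at its zero value (false), so schedulers built through it take the slot-less path. A hypothetical test-only variant that exercises the slot-holding path might look like the following; it is not part of the commit:

```go
// MockSchedulerWithSlots mirrors MockScheduler but marks the scheduler as
// holding reserved slots, so the resuming/running branches of scheduleTask
// are not short-circuited. Hypothetical helper, for illustration only.
func (sm *Manager) MockSchedulerWithSlots(task *proto.Task) *BaseScheduler {
	return NewBaseScheduler(sm.ctx, task, Param{
		taskMgr:        sm.taskMgr,
		nodeMgr:        sm.nodeMgr,
		slotMgr:        sm.slotMgr,
		serverID:       sm.serverID,
		allocatedSlots: true,
	})
}
```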