disttask: fix subtask finished immediately and mark success when encountering network partition #48660

Merged (10 commits) on Nov 17, 2023
Changes shown from 8 commits.
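One pattern runs through the whole diff: task-manager and handle calls gain an explicit `context.Context` (the DDL worker's `w.ctx`, the dispatcher's `d.ctx`, or the manager's `dm.ctx`), and `FinishSubtask` additionally takes the owning scheduler ID. Below is a minimal sketch of why that combination guards against the reported failure; the in-memory store, the `finishSubtask` helper, and its signature are illustrative stand-ins, not the real `storage.TaskManager` API.

```go
package main

import (
	"context"
	"errors"
	"fmt"
)

// subtask is a simplified stand-in for a row in the framework's subtask table.
type subtask struct {
	ID          int64
	SchedulerID string // exec ID of the node that currently owns the subtask
	State       string
}

// finishSubtask mimics the hardened update: it respects the caller's context and
// only marks the subtask succeeded while it is still owned by schedulerID.
func finishSubtask(ctx context.Context, store map[int64]*subtask, schedulerID string, id int64) error {
	if err := ctx.Err(); err != nil {
		// A partitioned or shut-down node fails fast instead of recording a stale success.
		return err
	}
	st, ok := store[id]
	if !ok || st.SchedulerID != schedulerID {
		return errors.New("subtask no longer owned by this scheduler; skip finishing")
	}
	st.State = "succeed"
	return nil
}

func main() {
	store := map[int64]*subtask{
		1: {ID: 1, SchedulerID: "node-a", State: "running"},
	}

	// The owner finishes its subtask under a live context.
	fmt.Println(finishSubtask(context.Background(), store, "node-a", 1)) // <nil>

	// A node that lost ownership (e.g. after a network partition) cannot mark it finished.
	fmt.Println(finishSubtask(context.Background(), store, "node-b", 1)) // ownership error

	// A cancelled context also stops the write.
	ctx, cancel := context.WithCancel(context.Background())
	cancel()
	fmt.Println(finishSubtask(ctx, store, "node-a", 1)) // context canceled
}
```

With ownership re-checked at finish time and every query bounded by its caller's context, a node cut off by a partition can no longer immediately report success for a subtask that has since moved to another scheduler, which, judging from the title, is the failure mode this PR closes.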
20 changes: 9 additions & 11 deletions pkg/ddl/backfilling_dispatcher_test.go
@@ -35,10 +35,8 @@ import (
"github.com/pingcap/tidb/pkg/parser/mysql"
"github.com/pingcap/tidb/pkg/sessionctx/variable"
"github.com/pingcap/tidb/pkg/testkit"
"github.com/pingcap/tidb/pkg/util/logutil"
"github.com/stretchr/testify/require"
"github.com/tikv/client-go/v2/util"
"go.uber.org/zap"
)

func TestBackfillingDispatcherLocalMode(t *testing.T) {
@@ -156,7 +154,8 @@ func TestBackfillingDispatcherGlobalSortMode(t *testing.T) {
}, 1, 1, time.Second)
defer pool.Close()
ctx := context.WithValue(context.Background(), "etcd", true)
mgr := storage.NewTaskManager(util.WithInternalSourceType(ctx, "taskManager"), pool)
ctx = util.WithInternalSourceType(ctx, "handle")
mgr := storage.NewTaskManager(pool)
storage.SetTaskManager(mgr)
dspManager, err := dispatcher.NewManager(util.WithInternalSourceType(ctx, "dispatcher"), mgr, "host:port")
require.NoError(t, err)
@@ -175,7 +174,7 @@ func TestBackfillingDispatcherGlobalSortMode(t *testing.T) {
ext.(*ddl.BackfillingDispatcherExt).GlobalSort = true
dsp.Extension = ext

taskID, err := mgr.AddNewGlobalTask(task.Key, proto.Backfill, 1, task.Meta)
taskID, err := mgr.AddNewGlobalTask(ctx, task.Key, proto.Backfill, 1, task.Meta)
require.NoError(t, err)
task.ID = taskID
serverInfos, _, err := dsp.GetEligibleInstances(context.Background(), task)
@@ -192,11 +191,10 @@ func TestBackfillingDispatcherGlobalSortMode(t *testing.T) {
for _, m := range subtaskMetas {
subtasks = append(subtasks, proto.NewSubtask(task.Step, task.ID, task.Type, "", m))
}
_, err = mgr.UpdateGlobalTaskAndAddSubTasks(task, subtasks, proto.TaskStatePending)
_, err = mgr.UpdateGlobalTaskAndAddSubTasks(ctx, task, subtasks, proto.TaskStatePending)
require.NoError(t, err)
gotSubtasks, err := mgr.GetSubtasksForImportInto(taskID, ddl.StepReadIndex)
gotSubtasks, err := mgr.GetSubtasksForImportInto(ctx, taskID, ddl.StepReadIndex)
require.NoError(t, err)
logutil.BgLogger().Info("ywq test", zap.Any("len", len(gotSubtasks)))

// update meta, same as import into.
sortStepMeta := &ddl.BackfillSubTaskMeta{
@@ -216,7 +214,7 @@ func TestBackfillingDispatcherGlobalSortMode(t *testing.T) {
sortStepMetaBytes, err := json.Marshal(sortStepMeta)
require.NoError(t, err)
for _, s := range gotSubtasks {
require.NoError(t, mgr.FinishSubtask(s.ID, sortStepMetaBytes))
require.NoError(t, mgr.FinishSubtask(ctx, s.SchedulerID, s.ID, sortStepMetaBytes))
}
// 2. to merge-sort stage.
require.NoError(t, failpoint.Enable("github.com/pingcap/tidb/pkg/ddl/forceMergeSort", `return()`))
@@ -234,9 +232,9 @@ func TestBackfillingDispatcherGlobalSortMode(t *testing.T) {
for _, m := range subtaskMetas {
subtasks = append(subtasks, proto.NewSubtask(task.Step, task.ID, task.Type, "", m))
}
_, err = mgr.UpdateGlobalTaskAndAddSubTasks(task, subtasks, proto.TaskStatePending)
_, err = mgr.UpdateGlobalTaskAndAddSubTasks(ctx, task, subtasks, proto.TaskStatePending)
require.NoError(t, err)
gotSubtasks, err = mgr.GetSubtasksForImportInto(taskID, task.Step)
gotSubtasks, err = mgr.GetSubtasksForImportInto(ctx, taskID, task.Step)
require.NoError(t, err)
mergeSortStepMeta := &ddl.BackfillSubTaskMeta{
SortedKVMeta: external.SortedKVMeta{
@@ -255,7 +253,7 @@ func TestBackfillingDispatcherGlobalSortMode(t *testing.T) {
mergeSortStepMetaBytes, err := json.Marshal(mergeSortStepMeta)
require.NoError(t, err)
for _, s := range gotSubtasks {
require.NoError(t, mgr.FinishSubtask(s.ID, mergeSortStepMetaBytes))
require.NoError(t, mgr.FinishSubtask(ctx, s.SchedulerID, s.ID, mergeSortStepMetaBytes))
}
// 3. to write&ingest stage.
require.NoError(t, failpoint.Enable("github.com/pingcap/tidb/pkg/ddl/mockWriteIngest", "return(true)"))
14 changes: 8 additions & 6 deletions pkg/ddl/index.go
@@ -2060,6 +2060,8 @@ func (w *worker) executeDistGlobalTask(reorgInfo *reorgInfo) error {
taskType := proto.Backfill
taskKey := fmt.Sprintf("ddl/%s/%d", taskType, reorgInfo.Job.ID)
g, ctx := errgroup.WithContext(context.Background())
ctx = kv.WithInternalSourceType(ctx, kv.InternalDistTask)

done := make(chan struct{})

// generate taskKey for multi schema change.
@@ -2076,7 +2078,7 @@ func (w *worker) executeDistGlobalTask(reorgInfo *reorgInfo) error {
if err != nil {
return err
}
task, err := taskManager.GetGlobalTaskByKeyWithHistory(taskKey)
task, err := taskManager.GetGlobalTaskByKeyWithHistory(w.ctx, taskKey)
if err != nil {
return err
}
@@ -2095,7 +2097,7 @@ func (w *worker) executeDistGlobalTask(reorgInfo *reorgInfo) error {
backoffer := backoff.NewExponential(dispatcher.RetrySQLInterval, 2, dispatcher.RetrySQLMaxInterval)
err := handle.RunWithRetry(ctx, dispatcher.RetrySQLTimes, backoffer, logutil.BgLogger(),
func(ctx context.Context) (bool, error) {
return true, handle.ResumeTask(taskKey)
return true, handle.ResumeTask(w.ctx, taskKey)
},
)
if err != nil {
@@ -2158,7 +2160,7 @@ func (w *worker) executeDistGlobalTask(reorgInfo *reorgInfo) error {
case <-checkFinishTk.C:
if err = w.isReorgRunnable(reorgInfo.Job.ID, true); err != nil {
if dbterror.ErrPausedDDLJob.Equal(err) {
if err = handle.PauseTask(taskKey); err != nil {
if err = handle.PauseTask(w.ctx, taskKey); err != nil {
logutil.BgLogger().Error("pause global task error", zap.String("category", "ddl"), zap.String("task_key", taskKey), zap.Error(err))
continue
}
@@ -2170,7 +2172,7 @@ func (w *worker) executeDistGlobalTask(reorgInfo *reorgInfo) error {
if !dbterror.ErrCancelledDDLJob.Equal(err) {
return errors.Trace(err)
}
if err = handle.CancelGlobalTask(taskKey); err != nil {
if err = handle.CancelGlobalTask(w.ctx, taskKey); err != nil {
logutil.BgLogger().Error("cancel global task error", zap.String("category", "ddl"), zap.String("task_key", taskKey), zap.Error(err))
// continue to cancel global task.
continue
@@ -2191,12 +2193,12 @@ func (w *worker) updateJobRowCount(taskKey string, jobID int64) {
logutil.BgLogger().Warn("cannot get task manager", zap.String("category", "ddl"), zap.String("task_key", taskKey), zap.Error(err))
return
}
gTask, err := taskMgr.GetGlobalTaskByKey(taskKey)
gTask, err := taskMgr.GetGlobalTaskByKey(w.ctx, taskKey)
if err != nil || gTask == nil {
logutil.BgLogger().Warn("cannot get global task", zap.String("category", "ddl"), zap.String("task_key", taskKey), zap.Error(err))
return
}
rowCount, err := taskMgr.GetSubtaskRowCount(gTask.ID, proto.StepOne)
rowCount, err := taskMgr.GetSubtaskRowCount(w.ctx, gTask.ID, proto.StepOne)
if err != nil {
logutil.BgLogger().Warn("cannot get subtask row count", zap.String("category", "ddl"), zap.String("task_key", taskKey), zap.Error(err))
return
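In index.go the worker now tags its context with kv.InternalDistTask and passes w.ctx, rather than an implicit background context, into GetGlobalTaskByKeyWithHistory, ResumeTask, PauseTask, CancelGlobalTask, and the row-count queries, so those calls end as soon as the worker itself does. A rough sketch of that lifetime coupling; the `worker` type and `getTaskByKey` below are invented stand-ins for the real DDL worker and task-manager call.

```go
package main

import (
	"context"
	"fmt"
	"time"
)

// worker is a toy stand-in for the DDL worker: it owns a context that is cancelled
// when the worker shuts down, and every storage call below is bounded by it.
type worker struct {
	ctx    context.Context
	cancel context.CancelFunc
}

func newWorker() *worker {
	ctx, cancel := context.WithCancel(context.Background())
	return &worker{ctx: ctx, cancel: cancel}
}

// getTaskByKey stands in for a task-manager query; it returns as soon as the
// (simulated) query completes or the worker's context is done, whichever is first.
func (w *worker) getTaskByKey(key string) error {
	select {
	case <-time.After(50 * time.Millisecond): // simulated query latency
		fmt.Println("fetched task", key)
		return nil
	case <-w.ctx.Done():
		return w.ctx.Err()
	}
}

func main() {
	w := newWorker()
	_ = w.getTaskByKey("ddl/backfill/1") // succeeds while the worker is alive

	w.cancel()                                    // worker shut down or its session was lost
	fmt.Println(w.getTaskByKey("ddl/backfill/1")) // context canceled: no stale work proceeds
}
```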
1 change: 1 addition & 0 deletions pkg/disttask/framework/BUILD.bazel
@@ -26,6 +26,7 @@ go_test(
"//pkg/testkit",
"@com_github_pingcap_failpoint//:failpoint",
"@com_github_stretchr_testify//require",
"@com_github_tikv_client_go_v2//util",
"@org_uber_go_mock//gomock",
],
)
42 changes: 21 additions & 21 deletions pkg/disttask/framework/dispatcher/dispatcher.go
@@ -138,7 +138,7 @@ func (*BaseDispatcher) Close() {

// refreshTask fetch task state from tidb_global_task table.
func (d *BaseDispatcher) refreshTask() error {
newTask, err := d.taskMgr.GetGlobalTaskByID(d.Task.ID)
newTask, err := d.taskMgr.GetGlobalTaskByID(d.ctx, d.Task.ID)
if err != nil {
logutil.Logger(d.logCtx).Error("refresh task failed", zap.Error(err))
return err
@@ -166,7 +166,7 @@ func (d *BaseDispatcher) scheduleTask() {
}
failpoint.Inject("cancelTaskAfterRefreshTask", func(val failpoint.Value) {
if val.(bool) && d.Task.State == proto.TaskStateRunning {
err := d.taskMgr.CancelGlobalTask(d.Task.ID)
err := d.taskMgr.CancelGlobalTask(d.ctx, d.Task.ID)
if err != nil {
logutil.Logger(d.logCtx).Error("cancel task failed", zap.Error(err))
}
@@ -175,7 +175,7 @@

failpoint.Inject("pausePendingTask", func(val failpoint.Value) {
if val.(bool) && d.Task.State == proto.TaskStatePending {
_, err := d.taskMgr.PauseTask(d.Task.Key)
_, err := d.taskMgr.PauseTask(d.ctx, d.Task.Key)
if err != nil {
logutil.Logger(d.logCtx).Error("pause task failed", zap.Error(err))
}
@@ -185,7 +185,7 @@

failpoint.Inject("pauseTaskAfterRefreshTask", func(val failpoint.Value) {
if val.(bool) && d.Task.State == proto.TaskStateRunning {
_, err := d.taskMgr.PauseTask(d.Task.Key)
_, err := d.taskMgr.PauseTask(d.ctx, d.Task.Key)
if err != nil {
logutil.Logger(d.logCtx).Error("pause task failed", zap.Error(err))
}
@@ -243,7 +243,7 @@ func (d *BaseDispatcher) onCancelling() error {
// handle task in pausing state, cancel all running subtasks.
func (d *BaseDispatcher) onPausing() error {
logutil.Logger(d.logCtx).Info("on pausing state", zap.Stringer("state", d.Task.State), zap.Int64("stage", int64(d.Task.Step)))
cnt, err := d.taskMgr.GetSubtaskInStatesCnt(d.Task.ID, proto.TaskStateRunning, proto.TaskStatePending)
cnt, err := d.taskMgr.GetSubtaskInStatesCnt(d.ctx, d.Task.ID, proto.TaskStateRunning, proto.TaskStatePending)
if err != nil {
logutil.Logger(d.logCtx).Warn("check task failed", zap.Error(err))
return err
@@ -276,7 +276,7 @@ var TestSyncChan = make(chan struct{})
// handle task in resuming state
func (d *BaseDispatcher) onResuming() error {
logutil.Logger(d.logCtx).Info("on resuming state", zap.Stringer("state", d.Task.State), zap.Int64("stage", int64(d.Task.Step)))
cnt, err := d.taskMgr.GetSubtaskInStatesCnt(d.Task.ID, proto.TaskStatePaused)
cnt, err := d.taskMgr.GetSubtaskInStatesCnt(d.ctx, d.Task.ID, proto.TaskStatePaused)
if err != nil {
logutil.Logger(d.logCtx).Warn("check task failed", zap.Error(err))
return err
@@ -291,13 +291,13 @@ func (d *BaseDispatcher) onResuming() error {
return err
}

return d.taskMgr.ResumeSubtasks(d.Task.ID)
return d.taskMgr.ResumeSubtasks(d.ctx, d.Task.ID)
}

// handle task in reverting state, check all revert subtasks finished.
func (d *BaseDispatcher) onReverting() error {
logutil.Logger(d.logCtx).Debug("on reverting state", zap.Stringer("state", d.Task.State), zap.Int64("stage", int64(d.Task.Step)))
cnt, err := d.taskMgr.GetSubtaskInStatesCnt(d.Task.ID, proto.TaskStateRevertPending, proto.TaskStateReverting)
cnt, err := d.taskMgr.GetSubtaskInStatesCnt(d.ctx, d.Task.ID, proto.TaskStateRevertPending, proto.TaskStateReverting)
if err != nil {
logutil.Logger(d.logCtx).Warn("check task failed", zap.Error(err))
return err
@@ -323,7 +323,7 @@ func (d *BaseDispatcher) onPending() error {
// If subtasks finished, run into the next stage.
func (d *BaseDispatcher) onRunning() error {
logutil.Logger(d.logCtx).Debug("on running state", zap.Stringer("state", d.Task.State), zap.Int64("stage", int64(d.Task.Step)))
subTaskErrs, err := d.taskMgr.CollectSubTaskError(d.Task.ID)
subTaskErrs, err := d.taskMgr.CollectSubTaskError(d.ctx, d.Task.ID)
if err != nil {
logutil.Logger(d.logCtx).Warn("collect subtask error failed", zap.Error(err))
return err
@@ -333,7 +333,7 @@ func (d *BaseDispatcher) onRunning() error {
return d.onErrHandlingStage(subTaskErrs)
}
// check current stage finished.
cnt, err := d.taskMgr.GetSubtaskInStatesCnt(d.Task.ID, proto.TaskStatePending, proto.TaskStateRunning)
cnt, err := d.taskMgr.GetSubtaskInStatesCnt(d.ctx, d.Task.ID, proto.TaskStatePending, proto.TaskStateRunning)
if err != nil {
logutil.Logger(d.logCtx).Warn("check task failed", zap.Error(err))
return err
@@ -355,13 +355,13 @@ func (d *BaseDispatcher) onRunning() error {
func (d *BaseDispatcher) onFinished() error {
metrics.UpdateMetricsForFinishTask(d.Task)
logutil.Logger(d.logCtx).Debug("schedule task, task is finished", zap.Stringer("state", d.Task.State))
return d.taskMgr.TransferSubTasks2History(d.Task.ID)
return d.taskMgr.TransferSubTasks2History(d.ctx, d.Task.ID)
}

func (d *BaseDispatcher) replaceDeadNodesIfAny() error {
if len(d.taskNodes) == 0 {
var err error
d.taskNodes, err = d.taskMgr.GetSchedulerIDsByTaskIDAndStep(d.Task.ID, d.Task.Step)
d.taskNodes, err = d.taskMgr.GetSchedulerIDsByTaskIDAndStep(d.ctx, d.Task.ID, d.Task.Step)
if err != nil {
return err
}
@@ -411,10 +411,10 @@ func (d *BaseDispatcher) replaceDeadNodesIfAny() error {
}
if len(replaceNodes) > 0 {
logutil.Logger(d.logCtx).Info("reschedule subtasks to other nodes", zap.Int("node-cnt", len(replaceNodes)))
if err := d.taskMgr.UpdateFailedSchedulerIDs(d.Task.ID, replaceNodes); err != nil {
if err := d.taskMgr.UpdateFailedSchedulerIDs(d.ctx, d.Task.ID, replaceNodes); err != nil {
return err
}
if err := d.taskMgr.CleanUpMeta(cleanNodes); err != nil {
if err := d.taskMgr.CleanUpMeta(d.ctx, cleanNodes); err != nil {
return err
}
// replace local cache.
@@ -441,15 +441,15 @@ func (d *BaseDispatcher) updateTask(taskState proto.TaskState, newSubTasks []*pr
}

failpoint.Inject("cancelBeforeUpdate", func() {
err := d.taskMgr.CancelGlobalTask(d.Task.ID)
err := d.taskMgr.CancelGlobalTask(d.ctx, d.Task.ID)
if err != nil {
logutil.Logger(d.logCtx).Error("cancel task failed", zap.Error(err))
}
})

var retryable bool
for i := 0; i < retryTimes; i++ {
retryable, err = d.taskMgr.UpdateGlobalTaskAndAddSubTasks(d.Task, newSubTasks, prevState)
retryable, err = d.taskMgr.UpdateGlobalTaskAndAddSubTasks(d.ctx, d.Task, newSubTasks, prevState)
if err == nil || !retryable {
break
}
@@ -658,13 +658,13 @@ func GenerateSchedulerNodes(ctx context.Context) (serverNodes []*infosync.Server
}

func (d *BaseDispatcher) filterByRole(infos []*infosync.ServerInfo) ([]*infosync.ServerInfo, error) {
nodes, err := d.taskMgr.GetNodesByRole("background")
nodes, err := d.taskMgr.GetNodesByRole(d.ctx, "background")
if err != nil {
return nil, err
}

if len(nodes) == 0 {
nodes, err = d.taskMgr.GetNodesByRole("")
nodes, err = d.taskMgr.GetNodesByRole(d.ctx, "")
}

if err != nil {
@@ -693,7 +693,7 @@ func (d *BaseDispatcher) GetAllSchedulerIDs(ctx context.Context, task *proto.Tas
return nil, nil
}

schedulerIDs, err := d.taskMgr.GetSchedulerIDsByTaskID(task.ID)
schedulerIDs, err := d.taskMgr.GetSchedulerIDsByTaskID(d.ctx, task.ID)
if err != nil {
return nil, err
}
@@ -708,7 +708,7 @@

// GetPreviousSubtaskMetas get subtask metas from specific step.
func (d *BaseDispatcher) GetPreviousSubtaskMetas(taskID int64, step proto.Step) ([][]byte, error) {
previousSubtasks, err := d.taskMgr.GetSucceedSubtasksByStep(taskID, step)
previousSubtasks, err := d.taskMgr.GetSucceedSubtasksByStep(d.ctx, taskID, step)
if err != nil {
logutil.Logger(d.logCtx).Warn("get previous succeed subtask failed", zap.Int64("step", int64(step)))
return nil, err
@@ -722,7 +722,7 @@ func (d *BaseDispatcher) GetPreviousSubtaskMetas(taskID int64, step proto.Step)

// GetPreviousSchedulerIDs gets scheduler IDs that run previous step.
func (d *BaseDispatcher) GetPreviousSchedulerIDs(_ context.Context, taskID int64, step proto.Step) ([]string, error) {
return d.taskMgr.GetSchedulerIDsByTaskIDAndStep(taskID, step)
return d.taskMgr.GetSchedulerIDsByTaskIDAndStep(d.ctx, taskID, step)
}

// WithNewSession executes the function with a new session.
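dispatcher.go gets the same treatment: every taskMgr call is now bounded by d.ctx, so state polling and retried SQL stop once the dispatcher shuts down. The sketch below illustrates the retry-until-cancelled idea with an invented retryWithCtx helper; the real code pairs handle.RunWithRetry with an exponential backoffer.

```go
package main

import (
	"context"
	"errors"
	"fmt"
	"time"
)

// retryWithCtx retries fn until it succeeds, the attempts run out, or ctx is
// cancelled, mirroring how the dispatcher's retried SQL stops once d.ctx dies.
func retryWithCtx(ctx context.Context, attempts int, interval time.Duration, fn func(context.Context) error) error {
	var err error
	for i := 0; i < attempts; i++ {
		if err = ctx.Err(); err != nil {
			return err // dispatcher is gone; give up immediately
		}
		if err = fn(ctx); err == nil {
			return nil
		}
		select {
		case <-time.After(interval):
		case <-ctx.Done():
			return ctx.Err()
		}
	}
	return err
}

func main() {
	ctx, cancel := context.WithCancel(context.Background())

	calls := 0
	flaky := func(context.Context) error {
		calls++
		if calls < 3 {
			return errors.New("transient SQL error")
		}
		return nil
	}

	fmt.Println(retryWithCtx(ctx, 5, 10*time.Millisecond, flaky)) // <nil> after three attempts

	cancel()
	fmt.Println(retryWithCtx(ctx, 5, 10*time.Millisecond, flaky)) // context canceled
}
```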
12 changes: 7 additions & 5 deletions pkg/disttask/framework/dispatcher/dispatcher_manager.go
@@ -169,6 +169,7 @@ func (dm *Manager) dispatchTaskLoop() {

// TODO: Consider getting these tasks, in addition to the task being worked on..
tasks, err := dm.taskMgr.GetGlobalTasksInStates(
dm.ctx,
proto.TaskStatePending,
proto.TaskStateRunning,
proto.TaskStateReverting,
@@ -223,7 +224,7 @@ func (dm *Manager) failTask(task *proto.Task, err error) {
prevState := task.State
task.State = proto.TaskStateFailed
task.Error = err
if _, err2 := dm.taskMgr.UpdateGlobalTaskAndAddSubTasks(task, nil, prevState); err2 != nil {
if _, err2 := dm.taskMgr.UpdateGlobalTaskAndAddSubTasks(dm.ctx, task, nil, prevState); err2 != nil {
logutil.BgLogger().Warn("failed to update task state to failed",
zap.Int64("task-id", task.ID), zap.Error(err2))
}
@@ -248,7 +249,7 @@ func (dm *Manager) gcSubtaskHistoryTableLoop() {
logutil.BgLogger().Info("subtask history table gc loop exits", zap.Error(dm.ctx.Err()))
return
case <-ticker.C:
err := dm.taskMgr.GCSubtasks()
err := dm.taskMgr.GCSubtasks(dm.ctx)
if err != nil {
logutil.BgLogger().Warn("subtask history table gc failed", zap.Error(err))
} else {
@@ -318,6 +319,7 @@ func (dm *Manager) doCleanUpRoutine() {
logutil.BgLogger().Info("clean up nodes in framework meta since nodes shutdown", zap.Int("cnt", cnt))
}
tasks, err := dm.taskMgr.GetGlobalTasksInStates(
dm.ctx,
proto.TaskStateFailed,
proto.TaskStateReverted,
proto.TaskStateSucceed,
@@ -350,7 +352,7 @@ func (dm *Manager) CleanUpMeta() int {
return 0
}

oldNodes, err := dm.taskMgr.GetAllNodes()
oldNodes, err := dm.taskMgr.GetAllNodes(dm.ctx)
if err != nil {
logutil.BgLogger().Warn("get all nodes met error")
return 0
@@ -366,7 +368,7 @@
return 0
}
logutil.BgLogger().Info("start to clean up dist_framework_meta")
err = dm.taskMgr.CleanUpMeta(cleanNodes)
err = dm.taskMgr.CleanUpMeta(dm.ctx, cleanNodes)
if err != nil {
logutil.BgLogger().Warn("clean up dist_framework_meta met error")
return 0
Expand Down Expand Up @@ -396,7 +398,7 @@ func (dm *Manager) cleanUpFinishedTasks(tasks []*proto.Task) error {
logutil.BgLogger().Warn("cleanUp routine failed", zap.Error(errors.Trace(firstErr)))
}

return dm.taskMgr.TransferTasks2History(cleanedTasks)
return dm.taskMgr.TransferTasks2History(dm.ctx, cleanedTasks)
}

// MockDispatcher mock one dispatcher for one task, only used for tests.
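dispatcher_manager.go completes the picture: the manager's polling and GC loops pass dm.ctx into GetGlobalTasksInStates, GCSubtasks, GetAllNodes, CleanUpMeta, and TransferTasks2History. A compact sketch of such a context-bounded ticker loop; runGCLoop and gcOnce are placeholders, not the framework's API.

```go
package main

import (
	"context"
	"fmt"
	"time"
)

// runGCLoop mimics gcSubtaskHistoryTableLoop: it fires gcOnce on every tick and
// exits promptly when ctx is cancelled, so no storage call outlives the manager.
func runGCLoop(ctx context.Context, interval time.Duration, gcOnce func(context.Context) error) {
	ticker := time.NewTicker(interval)
	defer ticker.Stop()
	for {
		select {
		case <-ctx.Done():
			fmt.Println("gc loop exits:", ctx.Err())
			return
		case <-ticker.C:
			if err := gcOnce(ctx); err != nil {
				fmt.Println("gc failed:", err)
			}
		}
	}
}

func main() {
	ctx, cancel := context.WithCancel(context.Background())

	go runGCLoop(ctx, 10*time.Millisecond, func(context.Context) error {
		fmt.Println("gc history subtasks")
		return nil
	})

	time.Sleep(35 * time.Millisecond) // let a few ticks fire
	cancel()                          // manager shuts down; the loop returns
	time.Sleep(10 * time.Millisecond) // give the goroutine time to print its exit message
}
```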