From 3a137850a93a26e9b67222dd4c2a58ccd671e7d2 Mon Sep 17 00:00:00 2001 From: GMHDBJD <35025882+GMHDBJD@users.noreply.github.com> Date: Wed, 21 Jun 2023 21:41:42 +0800 Subject: [PATCH] This is an automated cherry-pick of #44803 Signed-off-by: ti-chi-bot --- br/pkg/checksum/executor.go | 20 +- br/pkg/lightning/backend/local/BUILD.bazel | 2 + br/pkg/lightning/backend/local/checksum.go | 45 +- .../lightning/backend/local/checksum_test.go | 2 + br/pkg/lightning/common/BUILD.bazel | 5 + br/pkg/lightning/common/common.go | 1 + br/pkg/lightning/common/util.go | 163 +++++ br/pkg/lightning/config/config.go | 2 + br/pkg/lightning/importer/checksum_helper.go | 13 +- .../lightning/importer/table_import_test.go | 3 +- br/pkg/lightning/importer/tidb_test.go | 2 + br/tests/lightning_add_index/config1.toml | 3 + disttask/framework/dispatcher/dispatcher.go | 29 + disttask/framework/storage/task_table.go | 75 +- disttask/importinto/BUILD.bazel | 80 +++ disttask/importinto/dispatcher.go | 647 ++++++++++++++++++ disttask/importinto/job.go | 279 ++++++++ disttask/importinto/subtask_executor.go | 240 +++++++ disttask/loaddata/subtask_executor_test.go | 73 ++ executor/import_into.go | 302 ++++++++ executor/importer/BUILD.bazel | 5 + executor/importer/table_import.go | 9 +- tests/realtikvtest/importintotest/job_test.go | 635 +++++++++++++++++ 23 files changed, 2618 insertions(+), 17 deletions(-) create mode 100644 disttask/importinto/BUILD.bazel create mode 100644 disttask/importinto/dispatcher.go create mode 100644 disttask/importinto/job.go create mode 100644 disttask/importinto/subtask_executor.go create mode 100644 disttask/loaddata/subtask_executor_test.go create mode 100644 executor/import_into.go create mode 100644 tests/realtikvtest/importintotest/job_test.go diff --git a/br/pkg/checksum/executor.go b/br/pkg/checksum/executor.go index 0159fe43b1d0c..409d40a3530a6 100644 --- a/br/pkg/checksum/executor.go +++ b/br/pkg/checksum/executor.go @@ -28,7 +28,8 @@ type ExecutorBuilder struct { oldTable *metautil.Table - concurrency uint + concurrency uint + backoffWeight int oldKeyspace []byte newKeyspace []byte @@ -56,6 +57,12 @@ func (builder *ExecutorBuilder) SetConcurrency(conc uint) *ExecutorBuilder { return builder } +// SetBackoffWeight set the backoffWeight of the checksum executing. +func (builder *ExecutorBuilder) SetBackoffWeight(backoffWeight int) *ExecutorBuilder { + builder.backoffWeight = backoffWeight + return builder +} + func (builder *ExecutorBuilder) SetOldKeyspace(keyspace []byte) *ExecutorBuilder { builder.oldKeyspace = keyspace return builder @@ -79,7 +86,7 @@ func (builder *ExecutorBuilder) Build() (*Executor, error) { if err != nil { return nil, errors.Trace(err) } - return &Executor{reqs: reqs}, nil + return &Executor{reqs: reqs, backoffWeight: builder.backoffWeight}, nil } func buildChecksumRequest( @@ -294,7 +301,8 @@ func updateChecksumResponse(resp, update *tipb.ChecksumResponse) { // Executor is a checksum executor. type Executor struct { - reqs []*kv.Request + reqs []*kv.Request + backoffWeight int } // Len returns the total number of checksum requests. 
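For context, a minimal sketch of how the new SetBackoffWeight builder option is consumed; it mirrors the TiKVChecksumManager.checksumDB change further down in this diff. The concrete concurrency and weight values, the package name, and the helper function name are assumptions made for illustration, not part of the patch.

package example

import (
	"github.com/pingcap/errors"
	"github.com/pingcap/tidb/br/pkg/checksum"
	"github.com/pingcap/tidb/parser/model"
)

// buildChecksumExec wires both knobs of the builder: the distsql scan
// concurrency and the new backoff weight, so checksum coprocessor requests
// keep retrying longer while regions are re-electing leaders.
func buildChecksumExec(tblInfo *model.TableInfo, ts uint64) (*checksum.Executor, error) {
	exec, err := checksum.NewExecutorBuilder(tblInfo, ts).
		SetConcurrency(15).  // assumed value of tidb_distsql_scan_concurrency
		SetBackoffWeight(6). // matches the DefaultBackoffWeight added later in this diff
		Build()
	if err != nil {
		return nil, errors.Trace(err)
	}
	return exec, nil
}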
@@ -347,7 +355,11 @@ func (exec *Executor) Execute( err error ) err = utils.WithRetry(ctx, func() error { - resp, err = sendChecksumRequest(ctx, client, req, kv.NewVariables(&killed)) + vars := kv.NewVariables(&killed) + if exec.backoffWeight > 0 { + vars.BackOffWeight = exec.backoffWeight + } + resp, err = sendChecksumRequest(ctx, client, req, vars) failpoint.Inject("checksumRetryErr", func(val failpoint.Value) { // first time reach here. return error if val.(bool) { diff --git a/br/pkg/lightning/backend/local/BUILD.bazel b/br/pkg/lightning/backend/local/BUILD.bazel index 211539e8ca73d..95f344c2f3c9c 100644 --- a/br/pkg/lightning/backend/local/BUILD.bazel +++ b/br/pkg/lightning/backend/local/BUILD.bazel @@ -47,6 +47,7 @@ go_library( "//kv", "//parser/model", "//parser/mysql", + "//sessionctx/variable", "//store/pdtypes", "//table", "//tablecodec", @@ -72,6 +73,7 @@ go_library( "@com_github_pingcap_kvproto//pkg/pdpb", "@com_github_pingcap_tipb//go-tipb", "@com_github_tikv_client_go_v2//error", + "@com_github_tikv_client_go_v2//kv", "@com_github_tikv_client_go_v2//oracle", "@com_github_tikv_client_go_v2//tikv", "@com_github_tikv_pd_client//:client", diff --git a/br/pkg/lightning/backend/local/checksum.go b/br/pkg/lightning/backend/local/checksum.go index 7c7c45514674a..310f3f0020c70 100644 --- a/br/pkg/lightning/backend/local/checksum.go +++ b/br/pkg/lightning/backend/local/checksum.go @@ -32,8 +32,10 @@ import ( "github.com/pingcap/tidb/br/pkg/lightning/metric" "github.com/pingcap/tidb/br/pkg/lightning/verification" "github.com/pingcap/tidb/kv" + "github.com/pingcap/tidb/sessionctx/variable" "github.com/pingcap/tidb/util/mathutil" "github.com/pingcap/tipb/go-tipb" + tikvstore "github.com/tikv/client-go/v2/kv" "github.com/tikv/client-go/v2/oracle" pd "github.com/tikv/pd/client" "go.uber.org/atomic" @@ -49,7 +51,14 @@ const ( var ( serviceSafePointTTL int64 = 10 * 60 // 10 min in seconds - minDistSQLScanConcurrency = 4 + // MinDistSQLScanConcurrency is the minimum value of tidb_distsql_scan_concurrency. + MinDistSQLScanConcurrency = 4 + + // DefaultBackoffWeight is the default value of tidb_backoff_weight for checksum. + // when TiKV client encounters an error of "region not leader", it will keep retrying every 500 ms. + // If it still fails after 2 * 20 = 40 seconds, it will return "region unavailable". + // If we increase the BackOffWeight to 6, then the TiKV client will keep retrying for 120 seconds. + DefaultBackoffWeight = 3 * tikvstore.DefBackOffWeight ) // RemoteChecksum represents a checksum result got from tidb. @@ -102,6 +111,15 @@ func (e *tidbChecksumExecutor) Checksum(ctx context.Context, tableInfo *checkpoi task := log.FromContext(ctx).With(zap.String("table", tableName)).Begin(zap.InfoLevel, "remote checksum") + conn, err := e.db.Conn(ctx) + if err != nil { + return nil, errors.Trace(err) + } + defer func() { + if err := conn.Close(); err != nil { + task.Warn("close connection failed", zap.Error(err)) + } + }() // ADMIN CHECKSUM TABLE ,
example. // mysql> admin checksum table test.t; // +---------+------------+---------------------+-----------+-------------+ @@ -109,9 +127,23 @@ func (e *tidbChecksumExecutor) Checksum(ctx context.Context, tableInfo *checkpoi // +---------+------------+---------------------+-----------+-------------+ // | test | t | 8520875019404689597 | 7296873 | 357601387 | // +---------+------------+---------------------+-----------+-------------+ + backoffWeight, err := common.GetBackoffWeightFromDB(ctx, e.db) + if err == nil && backoffWeight < DefaultBackoffWeight { + task.Info("increase tidb_backoff_weight", zap.Int("original", backoffWeight), zap.Int("new", DefaultBackoffWeight)) + // increase backoff weight + if _, err := conn.ExecContext(ctx, fmt.Sprintf("SET SESSION %s = '%d';", variable.TiDBBackOffWeight, DefaultBackoffWeight)); err != nil { + task.Warn("set tidb_backoff_weight failed", zap.Error(err)) + } else { + defer func() { + if _, err := conn.ExecContext(ctx, fmt.Sprintf("SET SESSION %s = '%d';", variable.TiDBBackOffWeight, backoffWeight)); err != nil { + task.Warn("recover tidb_backoff_weight failed", zap.Error(err)) + } + }() + } + } cs := RemoteChecksum{} - err = common.SQLWithRetry{DB: e.db, Logger: task.Logger}.QueryRow(ctx, "compute remote checksum", + err = common.SQLWithRetry{DB: conn, Logger: task.Logger}.QueryRow(ctx, "compute remote checksum", "ADMIN CHECKSUM TABLE "+tableName, &cs.Schema, &cs.Table, &cs.Checksum, &cs.TotalKVs, &cs.TotalBytes, ) dur := task.End(zap.ErrorLevel, err) @@ -239,22 +271,25 @@ type TiKVChecksumManager struct { client kv.Client manager gcTTLManager distSQLScanConcurrency uint + backoffWeight int } var _ ChecksumManager = &TiKVChecksumManager{} // NewTiKVChecksumManager return a new tikv checksum manager -func NewTiKVChecksumManager(client kv.Client, pdClient pd.Client, distSQLScanConcurrency uint) *TiKVChecksumManager { +func NewTiKVChecksumManager(client kv.Client, pdClient pd.Client, distSQLScanConcurrency uint, backoffWeight int) *TiKVChecksumManager { return &TiKVChecksumManager{ client: client, manager: newGCTTLManager(pdClient), distSQLScanConcurrency: distSQLScanConcurrency, + backoffWeight: backoffWeight, } } func (e *TiKVChecksumManager) checksumDB(ctx context.Context, tableInfo *checkpoints.TidbTableInfo, ts uint64) (*RemoteChecksum, error) { executor, err := checksum.NewExecutorBuilder(tableInfo.Core, ts). SetConcurrency(e.distSQLScanConcurrency). + SetBackoffWeight(e.backoffWeight). Build() if err != nil { return nil, errors.Trace(err) @@ -286,8 +321,8 @@ func (e *TiKVChecksumManager) checksumDB(ctx context.Context, tableInfo *checkpo if !common.IsRetryableError(err) { break } - if distSQLScanConcurrency > minDistSQLScanConcurrency { - distSQLScanConcurrency = mathutil.Max(distSQLScanConcurrency/2, minDistSQLScanConcurrency) + if distSQLScanConcurrency > MinDistSQLScanConcurrency { + distSQLScanConcurrency = mathutil.Max(distSQLScanConcurrency/2, MinDistSQLScanConcurrency) } } diff --git a/br/pkg/lightning/backend/local/checksum_test.go b/br/pkg/lightning/backend/local/checksum_test.go index d00387157403e..eae3f7909e3eb 100644 --- a/br/pkg/lightning/backend/local/checksum_test.go +++ b/br/pkg/lightning/backend/local/checksum_test.go @@ -48,6 +48,7 @@ func TestDoChecksum(t *testing.T) { WithArgs("10m"). 
WillReturnResult(sqlmock.NewResult(2, 1)) mock.ExpectClose() + mock.ExpectClose() manager := NewTiDBChecksumExecutor(db) checksum, err := manager.Checksum(context.Background(), &TidbTableInfo{DB: "test", Name: "t"}) @@ -215,6 +216,7 @@ func TestDoChecksumWithErrorAndLongOriginalLifetime(t *testing.T) { WithArgs("300h"). WillReturnResult(sqlmock.NewResult(1, 1)) mock.ExpectClose() + mock.ExpectClose() manager := NewTiDBChecksumExecutor(db) _, err = manager.Checksum(context.Background(), &TidbTableInfo{DB: "test", Name: "t"}) diff --git a/br/pkg/lightning/common/BUILD.bazel b/br/pkg/lightning/common/BUILD.bazel index af6a2f7095d2b..335ceca0b4f7d 100644 --- a/br/pkg/lightning/common/BUILD.bazel +++ b/br/pkg/lightning/common/BUILD.bazel @@ -26,6 +26,11 @@ go_library( "//kv", "//meta/autoid", "//parser/model", +<<<<<<< HEAD +======= + "//parser/mysql", + "//sessionctx/variable", +>>>>>>> 89bf7432279 (importinto/lightning: do remote checksum via sql (#44803)) "//store/driver/error", "//table/tables", "//util", diff --git a/br/pkg/lightning/common/common.go b/br/pkg/lightning/common/common.go index 4a9e0f461a477..aaf8860e4fb58 100644 --- a/br/pkg/lightning/common/common.go +++ b/br/pkg/lightning/common/common.go @@ -39,6 +39,7 @@ var DefaultImportantVariables = map[string]string{ "default_week_format": "0", "block_encryption_mode": "aes-128-ecb", "group_concat_max_len": "1024", + "tidb_backoff_weight": "6", } // DefaultImportVariablesTiDB is used in ObtainImportantVariables to retrieve the system diff --git a/br/pkg/lightning/common/util.go b/br/pkg/lightning/common/util.go index 7a0b4b8095582..ed2868f904ad2 100644 --- a/br/pkg/lightning/common/util.go +++ b/br/pkg/lightning/common/util.go @@ -37,6 +37,11 @@ import ( "github.com/pingcap/tidb/br/pkg/utils" tmysql "github.com/pingcap/tidb/errno" "github.com/pingcap/tidb/parser/model" +<<<<<<< HEAD +======= + tmysql "github.com/pingcap/tidb/parser/mysql" + "github.com/pingcap/tidb/sessionctx/variable" +>>>>>>> 89bf7432279 (importinto/lightning: do remote checksum via sql (#44803)) "github.com/pingcap/tidb/table/tables" "github.com/pingcap/tidb/util/codec" "go.uber.org/zap" @@ -488,3 +493,161 @@ func GetAutoRandomColumn(tblInfo *model.TableInfo) *model.ColumnInfo { } return nil } +<<<<<<< HEAD +======= + +// GetDropIndexInfos returns the index infos that need to be dropped and the remain indexes. +func GetDropIndexInfos( + tblInfo *model.TableInfo, +) (remainIndexes []*model.IndexInfo, dropIndexes []*model.IndexInfo) { + cols := tblInfo.Columns +loop: + for _, idxInfo := range tblInfo.Indices { + if idxInfo.State != model.StatePublic { + remainIndexes = append(remainIndexes, idxInfo) + continue + } + // Primary key is a cluster index. + if idxInfo.Primary && tblInfo.HasClusteredIndex() { + remainIndexes = append(remainIndexes, idxInfo) + continue + } + // Skip index that contains auto-increment column. + // Because auto colum must be defined as a key. + for _, idxCol := range idxInfo.Columns { + flag := cols[idxCol.Offset].GetFlag() + if tmysql.HasAutoIncrementFlag(flag) { + remainIndexes = append(remainIndexes, idxInfo) + continue loop + } + } + dropIndexes = append(dropIndexes, idxInfo) + } + return remainIndexes, dropIndexes +} + +// BuildDropIndexSQL builds the SQL statement to drop index. 
+func BuildDropIndexSQL(tableName string, idxInfo *model.IndexInfo) string { + if idxInfo.Primary { + return fmt.Sprintf("ALTER TABLE %s DROP PRIMARY KEY", tableName) + } + return fmt.Sprintf("ALTER TABLE %s DROP INDEX %s", tableName, EscapeIdentifier(idxInfo.Name.O)) +} + +// BuildAddIndexSQL builds the SQL statement to create missing indexes. +// It returns both a single SQL statement that creates all indexes at once, +// and a list of SQL statements that creates each index individually. +func BuildAddIndexSQL( + tableName string, + curTblInfo, + desiredTblInfo *model.TableInfo, +) (singleSQL string, multiSQLs []string) { + addIndexSpecs := make([]string, 0, len(desiredTblInfo.Indices)) +loop: + for _, desiredIdxInfo := range desiredTblInfo.Indices { + for _, curIdxInfo := range curTblInfo.Indices { + if curIdxInfo.Name.L == desiredIdxInfo.Name.L { + continue loop + } + } + + var buf bytes.Buffer + if desiredIdxInfo.Primary { + buf.WriteString("ADD PRIMARY KEY ") + } else if desiredIdxInfo.Unique { + buf.WriteString("ADD UNIQUE KEY ") + } else { + buf.WriteString("ADD KEY ") + } + // "primary" is a special name for primary key, we should not use it as index name. + if desiredIdxInfo.Name.L != "primary" { + buf.WriteString(EscapeIdentifier(desiredIdxInfo.Name.O)) + } + + colStrs := make([]string, 0, len(desiredIdxInfo.Columns)) + for _, col := range desiredIdxInfo.Columns { + var colStr string + if desiredTblInfo.Columns[col.Offset].Hidden { + colStr = fmt.Sprintf("(%s)", desiredTblInfo.Columns[col.Offset].GeneratedExprString) + } else { + colStr = EscapeIdentifier(col.Name.O) + if col.Length != types.UnspecifiedLength { + colStr = fmt.Sprintf("%s(%s)", colStr, strconv.Itoa(col.Length)) + } + } + colStrs = append(colStrs, colStr) + } + fmt.Fprintf(&buf, "(%s)", strings.Join(colStrs, ",")) + + if desiredIdxInfo.Invisible { + fmt.Fprint(&buf, " INVISIBLE") + } + if desiredIdxInfo.Comment != "" { + fmt.Fprintf(&buf, ` COMMENT '%s'`, format.OutputFormat(desiredIdxInfo.Comment)) + } + addIndexSpecs = append(addIndexSpecs, buf.String()) + } + if len(addIndexSpecs) == 0 { + return "", nil + } + + singleSQL = fmt.Sprintf("ALTER TABLE %s %s", tableName, strings.Join(addIndexSpecs, ", ")) + for _, spec := range addIndexSpecs { + multiSQLs = append(multiSQLs, fmt.Sprintf("ALTER TABLE %s %s", tableName, spec)) + } + return singleSQL, multiSQLs +} + +// IsDupKeyError checks if err is a duplicate index error. +func IsDupKeyError(err error) bool { + if merr, ok := errors.Cause(err).(*mysql.MySQLError); ok { + switch merr.Number { + case errno.ErrDupKeyName, errno.ErrMultiplePriKey, errno.ErrDupUnique: + return true + } + } + return false +} + +// GetBackoffWeightFromDB gets the backoff weight from database. +func GetBackoffWeightFromDB(ctx context.Context, db *sql.DB) (int, error) { + val, err := getSessionVariable(ctx, db, variable.TiDBBackOffWeight) + if err != nil { + return 0, err + } + return strconv.Atoi(val) +} + +// copy from dbutil to avoid import cycle +func getSessionVariable(ctx context.Context, db *sql.DB, variable string) (value string, err error) { + query := fmt.Sprintf("SHOW VARIABLES LIKE '%s'", variable) + rows, err := db.QueryContext(ctx, query) + + if err != nil { + return "", errors.Trace(err) + } + defer rows.Close() + + // Show an example. 
+ /* + mysql> SHOW VARIABLES LIKE "binlog_format"; + +---------------+-------+ + | Variable_name | Value | + +---------------+-------+ + | binlog_format | ROW | + +---------------+-------+ + */ + + for rows.Next() { + if err = rows.Scan(&variable, &value); err != nil { + return "", errors.Trace(err) + } + } + + if err := rows.Err(); err != nil { + return "", errors.Trace(err) + } + + return value, nil +} +>>>>>>> 89bf7432279 (importinto/lightning: do remote checksum via sql (#44803)) diff --git a/br/pkg/lightning/config/config.go b/br/pkg/lightning/config/config.go index 801a05e07e553..577b563f76352 100644 --- a/br/pkg/lightning/config/config.go +++ b/br/pkg/lightning/config/config.go @@ -592,6 +592,7 @@ type PostRestore struct { Level1Compact bool `toml:"level-1-compact" json:"level-1-compact"` PostProcessAtLast bool `toml:"post-process-at-last" json:"post-process-at-last"` Compact bool `toml:"compact" json:"compact"` + ChecksumViaSQL bool `toml:"checksum-via-sql" json:"checksum-via-sql"` } // StringOrStringSlice can unmarshal a TOML string as string slice with one element. @@ -965,6 +966,7 @@ func NewConfig() *Config { Checksum: OpLevelRequired, Analyze: OpLevelOptional, PostProcessAtLast: true, + ChecksumViaSQL: true, }, } } diff --git a/br/pkg/lightning/importer/checksum_helper.go b/br/pkg/lightning/importer/checksum_helper.go index e81703cfce85b..88bc40d5a72e1 100644 --- a/br/pkg/lightning/importer/checksum_helper.go +++ b/br/pkg/lightning/importer/checksum_helper.go @@ -20,6 +20,7 @@ import ( "github.com/pingcap/errors" "github.com/pingcap/tidb/br/pkg/lightning/backend/local" "github.com/pingcap/tidb/br/pkg/lightning/checkpoints" + "github.com/pingcap/tidb/br/pkg/lightning/common" "github.com/pingcap/tidb/br/pkg/lightning/config" "github.com/pingcap/tidb/br/pkg/lightning/log" "github.com/pingcap/tidb/br/pkg/lightning/metric" @@ -44,14 +45,22 @@ func NewChecksumManager(ctx context.Context, rc *Controller, store kv.Storage) ( // for v4.0.0 or upper, we can use the gc ttl api var manager local.ChecksumManager - if pdVersion.Major >= 4 { + if pdVersion.Major >= 4 && !rc.cfg.PostRestore.ChecksumViaSQL { tlsOpt := rc.tls.ToPDSecurityOption() pdCli, err := pd.NewClientWithContext(ctx, []string{pdAddr}, tlsOpt) if err != nil { return nil, errors.Trace(err) } - manager = local.NewTiKVChecksumManager(store.GetClient(), pdCli, uint(rc.cfg.TiDB.DistSQLScanConcurrency)) + backoffWeight, err := common.GetBackoffWeightFromDB(ctx, rc.db) + // only set backoff weight when it's smaller than default value + if err == nil && backoffWeight >= local.DefaultBackoffWeight { + log.FromContext(ctx).Info("get tidb_backoff_weight", zap.Int("backoff_weight", backoffWeight)) + } else { + log.FromContext(ctx).Info("set tidb_backoff_weight to default", zap.Int("backoff_weight", local.DefaultBackoffWeight)) + backoffWeight = local.DefaultBackoffWeight + } + manager = local.NewTiKVChecksumManager(store.GetClient(), pdCli, uint(rc.cfg.TiDB.DistSQLScanConcurrency), backoffWeight) } else { manager = local.NewTiDBChecksumExecutor(rc.db) } diff --git a/br/pkg/lightning/importer/table_import_test.go b/br/pkg/lightning/importer/table_import_test.go index b061e9624d68e..01ccf311f0e65 100644 --- a/br/pkg/lightning/importer/table_import_test.go +++ b/br/pkg/lightning/importer/table_import_test.go @@ -803,6 +803,7 @@ func (s *tableRestoreSuite) TestCompareChecksumSuccess() { WithArgs("10m"). 
WillReturnResult(sqlmock.NewResult(2, 1)) mock.ExpectClose() + mock.ExpectClose() ctx := MockDoChecksumCtx(db) remoteChecksum, err := DoChecksum(ctx, s.tr.tableInfo) @@ -833,7 +834,7 @@ func (s *tableRestoreSuite) TestCompareChecksumFailure() { WithArgs("10m"). WillReturnResult(sqlmock.NewResult(2, 1)) mock.ExpectClose() - + mock.ExpectClose() ctx := MockDoChecksumCtx(db) remoteChecksum, err := DoChecksum(ctx, s.tr.tableInfo) require.NoError(s.T(), err) diff --git a/br/pkg/lightning/importer/tidb_test.go b/br/pkg/lightning/importer/tidb_test.go index 64813d669ec7b..4c0c33e6efc1b 100644 --- a/br/pkg/lightning/importer/tidb_test.go +++ b/br/pkg/lightning/importer/tidb_test.go @@ -337,6 +337,7 @@ func TestObtainRowFormatVersionSucceed(t *testing.T) { sysVars := ObtainImportantVariables(ctx, s.db, true) require.Equal(t, map[string]string{ + "tidb_backoff_weight": "6", "tidb_row_format_version": "2", "max_allowed_packet": "1073741824", "div_precision_increment": "10", @@ -360,6 +361,7 @@ func TestObtainRowFormatVersionFailure(t *testing.T) { sysVars := ObtainImportantVariables(ctx, s.db, true) require.Equal(t, map[string]string{ + "tidb_backoff_weight": "6", "tidb_row_format_version": "1", "max_allowed_packet": "67108864", "div_precision_increment": "4", diff --git a/br/tests/lightning_add_index/config1.toml b/br/tests/lightning_add_index/config1.toml index 2391884fb6a56..36b03d49a1117 100644 --- a/br/tests/lightning_add_index/config1.toml +++ b/br/tests/lightning_add_index/config1.toml @@ -1,3 +1,6 @@ [tikv-importer] backend = 'local' add-index-by-sql = false + +[post-restore] +checksum-via-sql = false \ No newline at end of file diff --git a/disttask/framework/dispatcher/dispatcher.go b/disttask/framework/dispatcher/dispatcher.go index eb6de9c96e854..e399abc461b16 100644 --- a/disttask/framework/dispatcher/dispatcher.go +++ b/disttask/framework/dispatcher/dispatcher.go @@ -483,12 +483,41 @@ func (d *dispatcher) GetAllSchedulerIDs(ctx context.Context, gTaskID int64) ([]s return ids, nil } +<<<<<<< HEAD func matchServerInfo(serverInfos map[string]*infosync.ServerInfo, schedulerID string) bool { for _, serverInfo := range serverInfos { serverID := disttaskutil.GenerateExecID(serverInfo.IP, serverInfo.Port) if serverID == schedulerID { return true } +======= +func (d *dispatcher) GetPreviousSubtaskMetas(gTaskID int64, step int64) ([][]byte, error) { + previousSubtasks, err := d.taskMgr.GetSucceedSubtasksByStep(gTaskID, step) + if err != nil { + logutil.BgLogger().Warn("get previous succeed subtask failed", zap.Int64("ID", gTaskID), zap.Int64("step", step)) + return nil, err + } + previousSubtaskMetas := make([][]byte, 0, len(previousSubtasks)) + for _, subtask := range previousSubtasks { + previousSubtaskMetas = append(previousSubtaskMetas, subtask.Meta) + } + return previousSubtaskMetas, nil +} + +func (d *dispatcher) WithNewSession(fn func(se sessionctx.Context) error) error { + return d.taskMgr.WithNewSession(fn) +} + +func (d *dispatcher) WithNewTxn(ctx context.Context, fn func(se sessionctx.Context) error) error { + return d.taskMgr.WithNewTxn(ctx, fn) +} + +func (*dispatcher) checkConcurrencyOverflow(cnt int) bool { + if cnt >= DefaultDispatchConcurrency { + logutil.BgLogger().Info("dispatch task loop, running GTask cnt is more than concurrency", + zap.Int("running cnt", cnt), zap.Int("concurrency", DefaultDispatchConcurrency)) + return true +>>>>>>> 89bf7432279 (importinto/lightning: do remote checksum via sql (#44803)) } return false } diff --git a/disttask/framework/storage/task_table.go 
b/disttask/framework/storage/task_table.go index be54c9054eb94..9beab9a8852f8 100644 --- a/disttask/framework/storage/task_table.go +++ b/disttask/framework/storage/task_table.go @@ -34,6 +34,17 @@ import ( "go.uber.org/zap" ) +<<<<<<< HEAD +======= +// SessionExecutor defines the interface for executing SQLs in a session. +type SessionExecutor interface { + // WithNewSession executes the function with a new session. + WithNewSession(fn func(se sessionctx.Context) error) error + // WithNewTxn executes the fn in a new transaction. + WithNewTxn(ctx context.Context, fn func(se sessionctx.Context) error) error +} + +>>>>>>> 89bf7432279 (importinto/lightning: do remote checksum via sql (#44803)) // TaskManager is the manager of global/sub task. type TaskManager struct { ctx context.Context @@ -65,9 +76,9 @@ func SetTaskManager(is *TaskManager) { taskManagerInstance.Store(is) } -// execSQL executes the sql and returns the result. +// ExecSQL executes the sql and returns the result. // TODO: consider retry. -func execSQL(ctx context.Context, se sessionctx.Context, sql string, args ...interface{}) ([]chunk.Row, error) { +func ExecSQL(ctx context.Context, se sessionctx.Context, sql string, args ...interface{}) ([]chunk.Row, error) { rs, err := se.(sqlexec.SQLExecutor).ExecuteInternal(ctx, sql, args...) if err != nil { return nil, err @@ -114,9 +125,17 @@ func (stm *TaskManager) withNewSession(fn func(se sessionctx.Context) error) err return fn(se.(sessionctx.Context)) } +<<<<<<< HEAD func (stm *TaskManager) withNewTxn(fn func(se sessionctx.Context) error) error { return stm.withNewSession(func(se sessionctx.Context) (err error) { _, err = execSQL(stm.ctx, se, "begin") +======= +// WithNewTxn executes the fn in a new transaction. +func (stm *TaskManager) WithNewTxn(ctx context.Context, fn func(se sessionctx.Context) error) error { + ctx = util.WithInternalSourceType(ctx, kv.InternalDistTask) + return stm.WithNewSession(func(se sessionctx.Context) (err error) { + _, err = ExecSQL(ctx, se, "begin") +>>>>>>> 89bf7432279 (importinto/lightning: do remote checksum via sql (#44803)) if err != nil { return err } @@ -127,7 +146,7 @@ func (stm *TaskManager) withNewTxn(fn func(se sessionctx.Context) error) error { if success { sql = "commit" } - _, commitErr := execSQL(stm.ctx, se, sql) + _, commitErr := ExecSQL(ctx, se, sql) if err == nil && commitErr != nil { err = commitErr } @@ -143,8 +162,13 @@ func (stm *TaskManager) withNewTxn(fn func(se sessionctx.Context) error) error { } func (stm *TaskManager) executeSQLWithNewSession(ctx context.Context, sql string, args ...interface{}) (rs []chunk.Row, err error) { +<<<<<<< HEAD err = stm.withNewSession(func(se sessionctx.Context) error { rs, err = execSQL(ctx, se, sql, args...) +======= + err = stm.WithNewSession(func(se sessionctx.Context) error { + rs, err = ExecSQL(ctx, se, sql, args...) +>>>>>>> 89bf7432279 (importinto/lightning: do remote checksum via sql (#44803)) return err }) @@ -182,6 +206,30 @@ func (stm *TaskManager) AddNewGlobalTask(key, tp string, concurrency int, meta [ return } +<<<<<<< HEAD +======= +// AddGlobalTaskWithSession adds a new task to global task table with session. 
+func (stm *TaskManager) AddGlobalTaskWithSession(se sessionctx.Context, key, tp string, concurrency int, meta []byte) (taskID int64, err error) { + _, err = ExecSQL(stm.ctx, se, + `insert into mysql.tidb_global_task(task_key, type, state, concurrency, step, meta, state_update_time) + values (%?, %?, %?, %?, %?, %?, %?)`, + key, tp, proto.TaskStatePending, concurrency, proto.StepInit, meta, time.Now().UTC().String()) + if err != nil { + return 0, err + } + + rs, err := ExecSQL(stm.ctx, se, "select @@last_insert_id") + if err != nil { + return 0, err + } + + taskID = int64(rs[0].GetUint64(0)) + failpoint.Inject("testSetLastTaskID", func() { TestLastTaskID.Store(taskID) }) + + return taskID, nil +} + +>>>>>>> 89bf7432279 (importinto/lightning: do remote checksum via sql (#44803)) // GetNewGlobalTask get a new task from global task table, it's used by dispatcher only. func (stm *TaskManager) GetNewGlobalTask() (task *proto.Task, err error) { rs, err := stm.executeSQLWithNewSession(stm.ctx, "select id, task_key, type, dispatcher_id, state, start_time, state_update_time, meta, concurrency, step, error from mysql.tidb_global_task where state = %? limit 1", proto.TaskStatePending) @@ -369,9 +417,15 @@ func (stm *TaskManager) GetSchedulerIDsByTaskID(taskID int64) ([]string, error) // UpdateGlobalTaskAndAddSubTasks update the global task and add new subtasks func (stm *TaskManager) UpdateGlobalTaskAndAddSubTasks(gTask *proto.Task, subtasks []*proto.Subtask, isSubtaskRevert bool) error { +<<<<<<< HEAD return stm.withNewTxn(func(se sessionctx.Context) error { _, err := execSQL(stm.ctx, se, "update mysql.tidb_global_task set state = %?, dispatcher_id = %?, step = %?, state_update_time = %?, concurrency = %?, error = %? where id = %?", gTask.State, gTask.DispatcherID, gTask.Step, gTask.StateUpdateTime.UTC().String(), gTask.Concurrency, gTask.Error, gTask.ID) +======= + return stm.WithNewTxn(stm.ctx, func(se sessionctx.Context) error { + _, err := ExecSQL(stm.ctx, se, "update mysql.tidb_global_task set state = %?, dispatcher_id = %?, step = %?, state_update_time = %?, concurrency = %?, meta = %?, error = %? 
where id = %?", + gTask.State, gTask.DispatcherID, gTask.Step, gTask.StateUpdateTime.UTC().String(), gTask.Concurrency, gTask.Meta, gTask.Error, gTask.ID) +>>>>>>> 89bf7432279 (importinto/lightning: do remote checksum via sql (#44803)) if err != nil { return err } @@ -389,8 +443,13 @@ func (stm *TaskManager) UpdateGlobalTaskAndAddSubTasks(gTask *proto.Task, subtas for _, subtask := range subtasks { // TODO: insert subtasks in batch +<<<<<<< HEAD _, err = execSQL(stm.ctx, se, "insert into mysql.tidb_background_subtask(task_key, exec_id, meta, state, type, checkpoint) values (%?, %?, %?, %?, %?, %?)", gTask.ID, subtask.SchedulerID, subtask.Meta, subtaskState, proto.Type2Int(subtask.Type), []byte{}) +======= + _, err = ExecSQL(stm.ctx, se, "insert into mysql.tidb_background_subtask(step, task_key, exec_id, meta, state, type, checkpoint) values (%?, %?, %?, %?, %?, %?, %?)", + gTask.Step, gTask.ID, subtask.SchedulerID, subtask.Meta, subtaskState, proto.Type2Int(subtask.Type), []byte{}) +>>>>>>> 89bf7432279 (importinto/lightning: do remote checksum via sql (#44803)) if err != nil { return err } @@ -408,6 +467,16 @@ func (stm *TaskManager) CancelGlobalTask(taskID int64) error { return err } +<<<<<<< HEAD +======= +// CancelGlobalTaskByKeySession cancels global task by key using input session +func (stm *TaskManager) CancelGlobalTaskByKeySession(se sessionctx.Context, taskKey string) error { + _, err := ExecSQL(stm.ctx, se, "update mysql.tidb_global_task set state=%? where task_key=%? and state in (%?, %?)", + proto.TaskStateCancelling, taskKey, proto.TaskStatePending, proto.TaskStateRunning) + return err +} + +>>>>>>> 89bf7432279 (importinto/lightning: do remote checksum via sql (#44803)) // IsGlobalTaskCancelling checks whether the task state is cancelling func (stm *TaskManager) IsGlobalTaskCancelling(taskID int64) (bool, error) { rs, err := stm.executeSQLWithNewSession(stm.ctx, "select 1 from mysql.tidb_global_task where id=%? 
and state = %?", diff --git a/disttask/importinto/BUILD.bazel b/disttask/importinto/BUILD.bazel new file mode 100644 index 0000000000000..eabe0b7ecc10a --- /dev/null +++ b/disttask/importinto/BUILD.bazel @@ -0,0 +1,80 @@ +load("@io_bazel_rules_go//go:def.bzl", "go_library", "go_test") + +go_library( + name = "importinto", + srcs = [ + "dispatcher.go", + "job.go", + "proto.go", + "scheduler.go", + "subtask_executor.go", + "wrapper.go", + ], + importpath = "github.com/pingcap/tidb/disttask/importinto", + visibility = ["//visibility:public"], + deps = [ + "//br/pkg/lightning/backend", + "//br/pkg/lightning/backend/kv", + "//br/pkg/lightning/backend/local", + "//br/pkg/lightning/checkpoints", + "//br/pkg/lightning/common", + "//br/pkg/lightning/config", + "//br/pkg/lightning/mydump", + "//br/pkg/lightning/verification", + "//br/pkg/utils", + "//disttask/framework/dispatcher", + "//disttask/framework/handle", + "//disttask/framework/proto", + "//disttask/framework/scheduler", + "//disttask/framework/storage", + "//domain/infosync", + "//errno", + "//executor/asyncloaddata", + "//executor/importer", + "//kv", + "//parser/ast", + "//parser/mysql", + "//sessionctx", + "//sessionctx/variable", + "//table/tables", + "//util/dbterror/exeerrors", + "//util/etcd", + "//util/logutil", + "//util/mathutil", + "//util/sqlexec", + "@com_github_go_sql_driver_mysql//:mysql", + "@com_github_google_uuid//:uuid", + "@com_github_pingcap_errors//:errors", + "@com_github_pingcap_failpoint//:failpoint", + "@com_github_tikv_client_go_v2//util", + "@org_uber_go_atomic//:atomic", + "@org_uber_go_zap//:zap", + ], +) + +go_test( + name = "importinto_test", + timeout = "short", + srcs = [ + "dispatcher_test.go", + "subtask_executor_test.go", + ], + embed = [":importinto"], + flaky = True, + race = "on", + deps = [ + "//br/pkg/lightning/verification", + "//disttask/framework/proto", + "//disttask/framework/storage", + "//domain/infosync", + "//executor/importer", + "//parser/model", + "//testkit", + "//util/logutil", + "@com_github_ngaut_pools//:pools", + "@com_github_pingcap_failpoint//:failpoint", + "@com_github_stretchr_testify//require", + "@com_github_stretchr_testify//suite", + "@com_github_tikv_client_go_v2//util", + ], +) diff --git a/disttask/importinto/dispatcher.go b/disttask/importinto/dispatcher.go new file mode 100644 index 0000000000000..3f4d6822bc55d --- /dev/null +++ b/disttask/importinto/dispatcher.go @@ -0,0 +1,647 @@ +// Copyright 2023 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
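To illustrate the SessionExecutor interface and the exported ExecSQL/WithNewTxn added to disttask/framework/storage earlier in this diff, a hedged sketch of running two statements inside one transaction; the table names, column names, and function name here are hypothetical and not part of the patch.

package example

import (
	"context"

	"github.com/pingcap/tidb/disttask/framework/storage"
	"github.com/pingcap/tidb/sessionctx"
)

// markJobSucceeded runs both statements between the "begin" and "commit" that
// WithNewTxn issues internally, so either both take effect or neither does.
// test.job_status and test.job_history are hypothetical tables for this sketch.
func markJobSucceeded(ctx context.Context, mgr *storage.TaskManager, jobID int64) error {
	return mgr.WithNewTxn(ctx, func(se sessionctx.Context) error {
		if _, err := storage.ExecSQL(ctx, se,
			"update test.job_status set state = %? where job_id = %?", "succeeded", jobID); err != nil {
			return err
		}
		_, err := storage.ExecSQL(ctx, se,
			"insert into test.job_history(job_id, state) values (%?, %?)", jobID, "succeeded")
		return err
	})
}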
+ +package importinto + +import ( + "context" + "encoding/json" + "strconv" + "strings" + "sync" + "time" + + dmysql "github.com/go-sql-driver/mysql" + "github.com/pingcap/errors" + "github.com/pingcap/failpoint" + "github.com/pingcap/tidb/br/pkg/lightning/backend/kv" + "github.com/pingcap/tidb/br/pkg/lightning/checkpoints" + "github.com/pingcap/tidb/br/pkg/lightning/common" + "github.com/pingcap/tidb/br/pkg/lightning/config" + verify "github.com/pingcap/tidb/br/pkg/lightning/verification" + "github.com/pingcap/tidb/br/pkg/utils" + "github.com/pingcap/tidb/disttask/framework/dispatcher" + "github.com/pingcap/tidb/disttask/framework/proto" + "github.com/pingcap/tidb/disttask/framework/storage" + "github.com/pingcap/tidb/domain/infosync" + "github.com/pingcap/tidb/errno" + "github.com/pingcap/tidb/executor/importer" + "github.com/pingcap/tidb/parser/ast" + "github.com/pingcap/tidb/sessionctx" + "github.com/pingcap/tidb/table/tables" + "github.com/pingcap/tidb/util/etcd" + "github.com/pingcap/tidb/util/logutil" + "github.com/pingcap/tidb/util/sqlexec" + "go.uber.org/atomic" + "go.uber.org/zap" +) + +const ( + registerTaskTTL = 10 * time.Minute + refreshTaskTTLInterval = 3 * time.Minute + registerTimeout = 5 * time.Second +) + +// NewTaskRegisterWithTTL is the ctor for TaskRegister. +// It is exported for testing. +var NewTaskRegisterWithTTL = utils.NewTaskRegisterWithTTL + +type taskInfo struct { + taskID int64 + + // operation on taskInfo is run inside detect-task goroutine, so no need to synchronize. + lastRegisterTime time.Time + + // initialized lazily in register() + etcdClient *etcd.Client + taskRegister utils.TaskRegister +} + +func (t *taskInfo) register(ctx context.Context) { + if time.Since(t.lastRegisterTime) < refreshTaskTTLInterval { + return + } + + if time.Since(t.lastRegisterTime) < refreshTaskTTLInterval { + return + } + logger := logutil.BgLogger().With(zap.Int64("task-id", t.taskID)) + if t.taskRegister == nil { + client, err := importer.GetEtcdClient() + if err != nil { + logger.Warn("get etcd client failed", zap.Error(err)) + return + } + t.etcdClient = client + t.taskRegister = NewTaskRegisterWithTTL(client.GetClient(), registerTaskTTL, + utils.RegisterImportInto, strconv.FormatInt(t.taskID, 10)) + } + timeoutCtx, cancel := context.WithTimeout(ctx, registerTimeout) + defer cancel() + if err := t.taskRegister.RegisterTaskOnce(timeoutCtx); err != nil { + logger.Warn("register task failed", zap.Error(err)) + } else { + logger.Info("register task to pd or refresh lease success") + } + // we set it even if register failed, TTL is 10min, refresh interval is 3min, + // we can try 2 times before the lease is expired. + t.lastRegisterTime = time.Now() +} + +func (t *taskInfo) close(ctx context.Context) { + logger := logutil.BgLogger().With(zap.Int64("task-id", t.taskID)) + if t.taskRegister != nil { + timeoutCtx, cancel := context.WithTimeout(ctx, registerTimeout) + defer cancel() + if err := t.taskRegister.Close(timeoutCtx); err != nil { + logger.Warn("unregister task failed", zap.Error(err)) + } else { + logger.Info("unregister task success") + } + t.taskRegister = nil + } + if t.etcdClient != nil { + if err := t.etcdClient.Close(); err != nil { + logger.Warn("close etcd client failed", zap.Error(err)) + } + t.etcdClient = nil + } +} + +type flowHandle struct { + mu sync.RWMutex + // NOTE: there's no need to sync for below 2 fields actually, since we add a restriction that only one + // task can be running at a time. 
but we might support task queuing in the future, leave it for now. + // the last time we switch TiKV into IMPORT mode, this is a global operation, do it for one task makes + // no difference to do it for all tasks. So we do not need to record the switch time for each task. + lastSwitchTime atomic.Time + // taskInfoMap is a map from taskID to taskInfo + taskInfoMap sync.Map + + // currTaskID is the taskID of the current running task. + // It may be changed when we switch to a new task or switch to a new owner. + currTaskID atomic.Int64 + disableTiKVImportMode atomic.Bool +} + +var _ dispatcher.TaskFlowHandle = (*flowHandle)(nil) + +func (h *flowHandle) OnTicker(ctx context.Context, task *proto.Task) { + // only switch TiKV mode or register task when task is running + if task.State != proto.TaskStateRunning { + return + } + h.switchTiKVMode(ctx, task) + h.registerTask(ctx, task) +} + +func (h *flowHandle) switchTiKVMode(ctx context.Context, task *proto.Task) { + h.updateCurrentTask(task) + // only import step need to switch to IMPORT mode, + // If TiKV is in IMPORT mode during checksum, coprocessor will time out. + if h.disableTiKVImportMode.Load() || task.Step != StepImport { + return + } + + if time.Since(h.lastSwitchTime.Load()) < config.DefaultSwitchTiKVModeInterval { + return + } + + h.mu.Lock() + defer h.mu.Unlock() + if time.Since(h.lastSwitchTime.Load()) < config.DefaultSwitchTiKVModeInterval { + return + } + + logger := logutil.BgLogger().With(zap.Int64("task-id", task.ID)) + switcher, err := importer.GetTiKVModeSwitcher(logger) + if err != nil { + logger.Warn("get tikv mode switcher failed", zap.Error(err)) + return + } + switcher.ToImportMode(ctx) + h.lastSwitchTime.Store(time.Now()) +} + +func (h *flowHandle) registerTask(ctx context.Context, task *proto.Task) { + val, _ := h.taskInfoMap.LoadOrStore(task.ID, &taskInfo{taskID: task.ID}) + info := val.(*taskInfo) + info.register(ctx) +} + +func (h *flowHandle) unregisterTask(ctx context.Context, task *proto.Task) { + if val, loaded := h.taskInfoMap.LoadAndDelete(task.ID); loaded { + info := val.(*taskInfo) + info.close(ctx) + } +} + +func (h *flowHandle) ProcessNormalFlow(ctx context.Context, handle dispatcher.TaskHandle, gTask *proto.Task) ( + resSubtaskMeta [][]byte, err error) { + logger := logutil.BgLogger().With( + zap.String("type", gTask.Type), + zap.Int64("task-id", gTask.ID), + zap.String("step", stepStr(gTask.Step)), + ) + taskMeta := &TaskMeta{} + err = json.Unmarshal(gTask.Meta, taskMeta) + if err != nil { + return nil, err + } + logger.Info("process normal flow") + + defer func() { + // currently, framework will take the task as finished when err is not nil or resSubtaskMeta is empty. + taskFinished := err == nil && len(resSubtaskMeta) == 0 + if taskFinished { + // todo: we're not running in a transaction with task update + if err2 := h.finishJob(ctx, handle, gTask, taskMeta); err2 != nil { + err = err2 + } + } else if err != nil && !h.IsRetryableErr(err) { + if err2 := h.failJob(ctx, handle, gTask, taskMeta, logger, err.Error()); err2 != nil { + // todo: we're not running in a transaction with task update, there might be case + // failJob return error, but task update succeed. 
+ logger.Error("call failJob failed", zap.Error(err2)) + } + } + }() + + switch gTask.Step { + case proto.StepInit: + if err := preProcess(ctx, handle, gTask, taskMeta, logger); err != nil { + return nil, err + } + if err = startJob(ctx, handle, taskMeta); err != nil { + return nil, err + } + subtaskMetas, err := generateImportStepMetas(ctx, taskMeta) + if err != nil { + return nil, err + } + logger.Info("move to import step", zap.Any("subtask-count", len(subtaskMetas))) + metaBytes := make([][]byte, 0, len(subtaskMetas)) + for _, subtaskMeta := range subtaskMetas { + bs, err := json.Marshal(subtaskMeta) + if err != nil { + return nil, err + } + metaBytes = append(metaBytes, bs) + } + gTask.Step = StepImport + return metaBytes, nil + case StepImport: + h.switchTiKV2NormalMode(ctx, gTask, logger) + failpoint.Inject("clearLastSwitchTime", func() { + h.lastSwitchTime.Store(time.Time{}) + }) + stepMeta, err2 := toPostProcessStep(handle, gTask, taskMeta) + if err2 != nil { + return nil, err2 + } + if err = job2Step(ctx, taskMeta, importer.JobStepValidating); err != nil { + return nil, err + } + logger.Info("move to post-process step ", zap.Any("result", taskMeta.Result), + zap.Any("step-meta", stepMeta)) + bs, err := json.Marshal(stepMeta) + if err != nil { + return nil, err + } + failpoint.Inject("failWhenDispatchPostProcessSubtask", func() { + failpoint.Return(nil, errors.New("injected error after StepImport")) + }) + gTask.Step = StepPostProcess + return [][]byte{bs}, nil + case StepPostProcess: + return nil, nil + default: + return nil, errors.Errorf("unknown step %d", gTask.Step) + } +} + +func (h *flowHandle) ProcessErrFlow(ctx context.Context, handle dispatcher.TaskHandle, gTask *proto.Task, receiveErr [][]byte) ([]byte, error) { + logger := logutil.BgLogger().With( + zap.String("type", gTask.Type), + zap.Int64("task-id", gTask.ID), + zap.String("step", stepStr(gTask.Step)), + ) + logger.Info("process error flow", zap.ByteStrings("error-message", receiveErr)) + taskMeta := &TaskMeta{} + err := json.Unmarshal(gTask.Meta, taskMeta) + if err != nil { + return nil, err + } + errStrs := make([]string, 0, len(receiveErr)) + for _, errStr := range receiveErr { + errStrs = append(errStrs, string(errStr)) + } + if err = h.failJob(ctx, handle, gTask, taskMeta, logger, strings.Join(errStrs, "; ")); err != nil { + return nil, err + } + + gTask.Error = receiveErr[0] + + errStr := string(receiveErr[0]) + // do nothing if the error is resumable + if isResumableErr(errStr) { + return nil, nil + } + + if gTask.Step == StepImport { + err = rollback(ctx, handle, gTask, logger) + if err != nil { + // TODO: add error code according to spec. + gTask.Error = []byte(errStr + ", " + err.Error()) + } + } + return nil, err +} + +func (*flowHandle) GetEligibleInstances(ctx context.Context, gTask *proto.Task) ([]*infosync.ServerInfo, error) { + taskMeta := &TaskMeta{} + err := json.Unmarshal(gTask.Meta, taskMeta) + if err != nil { + return nil, err + } + if len(taskMeta.EligibleInstances) > 0 { + return taskMeta.EligibleInstances, nil + } + return dispatcher.GenerateSchedulerNodes(ctx) +} + +func (*flowHandle) IsRetryableErr(error) bool { + // TODO: check whether the error is retryable. 
+ return false +} + +func (h *flowHandle) switchTiKV2NormalMode(ctx context.Context, task *proto.Task, logger *zap.Logger) { + h.updateCurrentTask(task) + if h.disableTiKVImportMode.Load() { + return + } + + h.mu.Lock() + defer h.mu.Unlock() + + switcher, err := importer.GetTiKVModeSwitcher(logger) + if err != nil { + logger.Warn("get tikv mode switcher failed", zap.Error(err)) + return + } + switcher.ToNormalMode(ctx) + + // clear it, so next task can switch TiKV mode again. + h.lastSwitchTime.Store(time.Time{}) +} + +func (h *flowHandle) updateCurrentTask(task *proto.Task) { + if h.currTaskID.Swap(task.ID) != task.ID { + taskMeta := &TaskMeta{} + if err := json.Unmarshal(task.Meta, taskMeta); err == nil { + h.disableTiKVImportMode.Store(taskMeta.Plan.DisableTiKVImportMode) + } + } +} + +// preProcess does the pre-processing for the task. +func preProcess(_ context.Context, _ dispatcher.TaskHandle, gTask *proto.Task, taskMeta *TaskMeta, logger *zap.Logger) error { + logger.Info("pre process") + // TODO: drop table indexes depends on the option. + // if err := dropTableIndexes(ctx, handle, taskMeta, logger); err != nil { + // return err + // } + return updateMeta(gTask, taskMeta) +} + +// nolint:deadcode +func dropTableIndexes(ctx context.Context, handle dispatcher.TaskHandle, taskMeta *TaskMeta, logger *zap.Logger) error { + tblInfo := taskMeta.Plan.TableInfo + tableName := common.UniqueTable(taskMeta.Plan.DBName, tblInfo.Name.L) + + remainIndexes, dropIndexes := common.GetDropIndexInfos(tblInfo) + for _, idxInfo := range dropIndexes { + sqlStr := common.BuildDropIndexSQL(tableName, idxInfo) + if err := executeSQL(ctx, handle, logger, sqlStr); err != nil { + if merr, ok := errors.Cause(err).(*dmysql.MySQLError); ok { + switch merr.Number { + case errno.ErrCantDropFieldOrKey, errno.ErrDropIndexNeededInForeignKey: + remainIndexes = append(remainIndexes, idxInfo) + logger.Warn("can't drop index, skip", zap.String("index", idxInfo.Name.O), zap.Error(err)) + continue + } + } + return err + } + } + if len(remainIndexes) < len(tblInfo.Indices) { + taskMeta.Plan.TableInfo = taskMeta.Plan.TableInfo.Clone() + taskMeta.Plan.TableInfo.Indices = remainIndexes + } + return nil +} + +// nolint:deadcode +func createTableIndexes(ctx context.Context, executor storage.SessionExecutor, taskMeta *TaskMeta, logger *zap.Logger) error { + tableName := common.UniqueTable(taskMeta.Plan.DBName, taskMeta.Plan.TableInfo.Name.L) + singleSQL, multiSQLs := common.BuildAddIndexSQL(tableName, taskMeta.Plan.TableInfo, taskMeta.Plan.DesiredTableInfo) + logger.Info("build add index sql", zap.String("singleSQL", singleSQL), zap.Strings("multiSQLs", multiSQLs)) + if len(multiSQLs) == 0 { + return nil + } + + err := executeSQL(ctx, executor, logger, singleSQL) + if err == nil { + return nil + } + if !common.IsDupKeyError(err) { + // TODO: refine err msg and error code according to spec. + return errors.Errorf("Failed to create index: %v, please execute the SQL manually, sql: %s", err, singleSQL) + } + if len(multiSQLs) == 1 { + return nil + } + logger.Warn("cannot add all indexes in one statement, try to add them one by one", zap.Strings("sqls", multiSQLs), zap.Error(err)) + + for i, ddl := range multiSQLs { + err := executeSQL(ctx, executor, logger, ddl) + if err != nil && !common.IsDupKeyError(err) { + // TODO: refine err msg and error code according to spec. 
+ return errors.Errorf("Failed to create index: %v, please execute the SQLs manually, sqls: %s", err, strings.Join(multiSQLs[i:], ";")) + } + } + return nil +} + +// TODO: return the result of sql. +func executeSQL(ctx context.Context, executor storage.SessionExecutor, logger *zap.Logger, sql string, args ...interface{}) (err error) { + logger.Info("execute sql", zap.String("sql", sql), zap.Any("args", args)) + return executor.WithNewSession(func(se sessionctx.Context) error { + _, err := se.(sqlexec.SQLExecutor).ExecuteInternal(ctx, sql, args...) + return err + }) +} + +func updateMeta(gTask *proto.Task, taskMeta *TaskMeta) error { + bs, err := json.Marshal(taskMeta) + if err != nil { + return err + } + gTask.Meta = bs + return nil +} + +func buildController(taskMeta *TaskMeta) (*importer.LoadDataController, error) { + idAlloc := kv.NewPanickingAllocators(0) + tbl, err := tables.TableFromMeta(idAlloc, taskMeta.Plan.TableInfo) + if err != nil { + return nil, err + } + + astArgs, err := importer.ASTArgsFromStmt(taskMeta.Stmt) + if err != nil { + return nil, err + } + controller, err := importer.NewLoadDataController(&taskMeta.Plan, tbl, astArgs) + if err != nil { + return nil, err + } + return controller, nil +} + +// todo: converting back and forth, we should unify struct and remove this function later. +func toChunkMap(engineCheckpoints map[int32]*checkpoints.EngineCheckpoint) map[int32][]Chunk { + chunkMap := make(map[int32][]Chunk, len(engineCheckpoints)) + for id, ecp := range engineCheckpoints { + chunkMap[id] = make([]Chunk, 0, len(ecp.Chunks)) + for _, chunkCheckpoint := range ecp.Chunks { + chunkMap[id] = append(chunkMap[id], toChunk(*chunkCheckpoint)) + } + } + return chunkMap +} + +func generateImportStepMetas(ctx context.Context, taskMeta *TaskMeta) (subtaskMetas []*ImportStepMeta, err error) { + var chunkMap map[int32][]Chunk + if len(taskMeta.ChunkMap) > 0 { + chunkMap = taskMeta.ChunkMap + } else { + controller, err2 := buildController(taskMeta) + if err2 != nil { + return nil, err2 + } + if err2 = controller.InitDataFiles(ctx); err2 != nil { + return nil, err2 + } + + engineCheckpoints, err2 := controller.PopulateChunks(ctx) + if err2 != nil { + return nil, err2 + } + chunkMap = toChunkMap(engineCheckpoints) + } + for id := range chunkMap { + if id == common.IndexEngineID { + continue + } + subtaskMeta := &ImportStepMeta{ + ID: id, + Chunks: chunkMap[id], + } + subtaskMetas = append(subtaskMetas, subtaskMeta) + } + return subtaskMetas, nil +} + +// we will update taskMeta in place and make gTask.Meta point to the new taskMeta. 
+func toPostProcessStep(handle dispatcher.TaskHandle, gTask *proto.Task, taskMeta *TaskMeta) (*PostProcessStepMeta, error) { + metas, err := handle.GetPreviousSubtaskMetas(gTask.ID, gTask.Step) + if err != nil { + return nil, err + } + + subtaskMetas := make([]*ImportStepMeta, 0, len(metas)) + for _, bs := range metas { + var subtaskMeta ImportStepMeta + if err := json.Unmarshal(bs, &subtaskMeta); err != nil { + return nil, err + } + subtaskMetas = append(subtaskMetas, &subtaskMeta) + } + var localChecksum verify.KVChecksum + columnSizeMap := make(map[int64]int64) + for _, subtaskMeta := range subtaskMetas { + checksum := verify.MakeKVChecksum(subtaskMeta.Checksum.Size, subtaskMeta.Checksum.KVs, subtaskMeta.Checksum.Sum) + localChecksum.Add(&checksum) + + taskMeta.Result.ReadRowCnt += subtaskMeta.Result.ReadRowCnt + taskMeta.Result.LoadedRowCnt += subtaskMeta.Result.LoadedRowCnt + for key, val := range subtaskMeta.Result.ColSizeMap { + columnSizeMap[key] += val + } + } + taskMeta.Result.ColSizeMap = columnSizeMap + if err2 := updateMeta(gTask, taskMeta); err2 != nil { + return nil, err2 + } + return &PostProcessStepMeta{ + Checksum: Checksum{ + Size: localChecksum.SumSize(), + KVs: localChecksum.SumKVS(), + Sum: localChecksum.Sum(), + }, + }, nil +} + +func startJob(ctx context.Context, handle dispatcher.TaskHandle, taskMeta *TaskMeta) error { + failpoint.Inject("syncBeforeJobStarted", func() { + TestSyncChan <- struct{}{} + <-TestSyncChan + }) + err := handle.WithNewSession(func(se sessionctx.Context) error { + exec := se.(sqlexec.SQLExecutor) + return importer.StartJob(ctx, exec, taskMeta.JobID) + }) + failpoint.Inject("syncAfterJobStarted", func() { + TestSyncChan <- struct{}{} + }) + return err +} + +func job2Step(ctx context.Context, taskMeta *TaskMeta, step string) error { + globalTaskManager, err := storage.GetTaskManager() + if err != nil { + return err + } + // todo: use dispatcher.TaskHandle + // we might call this in scheduler later, there's no dispatcher.TaskHandle, so we use globalTaskManager here. 
+ return globalTaskManager.WithNewSession(func(se sessionctx.Context) error { + exec := se.(sqlexec.SQLExecutor) + return importer.Job2Step(ctx, exec, taskMeta.JobID, step) + }) +} + +func (h *flowHandle) finishJob(ctx context.Context, handle dispatcher.TaskHandle, gTask *proto.Task, taskMeta *TaskMeta) error { + h.unregisterTask(ctx, gTask) + redactSensitiveInfo(gTask, taskMeta) + summary := &importer.JobSummary{ImportedRows: taskMeta.Result.LoadedRowCnt} + return handle.WithNewSession(func(se sessionctx.Context) error { + exec := se.(sqlexec.SQLExecutor) + return importer.FinishJob(ctx, exec, taskMeta.JobID, summary) + }) +} + +func (h *flowHandle) failJob(ctx context.Context, handle dispatcher.TaskHandle, gTask *proto.Task, + taskMeta *TaskMeta, logger *zap.Logger, errorMsg string) error { + h.switchTiKV2NormalMode(ctx, gTask, logger) + h.unregisterTask(ctx, gTask) + redactSensitiveInfo(gTask, taskMeta) + return handle.WithNewSession(func(se sessionctx.Context) error { + exec := se.(sqlexec.SQLExecutor) + return importer.FailJob(ctx, exec, taskMeta.JobID, errorMsg) + }) +} + +func redactSensitiveInfo(gTask *proto.Task, taskMeta *TaskMeta) { + taskMeta.Stmt = "" + taskMeta.Plan.Path = ast.RedactURL(taskMeta.Plan.Path) + if err := updateMeta(gTask, taskMeta); err != nil { + // marshal failed, should not happen + logutil.BgLogger().Warn("failed to update task meta", zap.Error(err)) + } +} + +// isResumableErr checks whether it's possible to rely on checkpoint to re-import data after the error has been fixed. +func isResumableErr(string) bool { + // TODO: add more cases + return false +} + +func rollback(ctx context.Context, handle dispatcher.TaskHandle, gTask *proto.Task, logger *zap.Logger) (err error) { + taskMeta := &TaskMeta{} + err = json.Unmarshal(gTask.Meta, taskMeta) + if err != nil { + return err + } + + logger.Info("rollback") + + // // TODO: create table indexes depends on the option. + // // create table indexes even if the rollback is failed. + // defer func() { + // err2 := createTableIndexes(ctx, handle, taskMeta, logger) + // err = multierr.Append(err, err2) + // }() + + tableName := common.UniqueTable(taskMeta.Plan.DBName, taskMeta.Plan.TableInfo.Name.L) + // truncate the table + return executeSQL(ctx, handle, logger, "TRUNCATE "+tableName) +} + +func stepStr(step int64) string { + switch step { + case proto.StepInit: + return "init" + case StepImport: + return "import" + case StepPostProcess: + return "postprocess" + default: + return "unknown" + } +} + +func init() { + dispatcher.RegisterTaskFlowHandle(proto.ImportInto, &flowHandle{}) +} diff --git a/disttask/importinto/job.go b/disttask/importinto/job.go new file mode 100644 index 0000000000000..64b61048d8c88 --- /dev/null +++ b/disttask/importinto/job.go @@ -0,0 +1,279 @@ +// Copyright 2023 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
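As a self-contained illustration of the checksum aggregation that toPostProcessStep performs above, the sketch below folds per-subtask (size, KV count, sum) triples into one verification.KVChecksum using only the APIs this diff already calls; the function name and sample numbers are assumptions.

package main

import (
	"fmt"

	verify "github.com/pingcap/tidb/br/pkg/lightning/verification"
)

// mergeSubtaskChecksums combines the checksums reported by each import subtask
// into the single value that the post-process step later compares against the
// remote ADMIN CHECKSUM TABLE result.
func mergeSubtaskChecksums(triples [][3]uint64) verify.KVChecksum {
	var total verify.KVChecksum
	for _, t := range triples {
		c := verify.MakeKVChecksum(t[0], t[1], t[2]) // size in bytes, KV count, checksum
		total.Add(&c)
	}
	return total
}

func main() {
	total := mergeSubtaskChecksums([][3]uint64{
		{1 << 20, 1000, 0xdeadbeef},
		{2 << 20, 2000, 0xcafebabe},
	})
	fmt.Println(total.SumSize(), total.SumKVS(), total.Sum())
}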
+ +package importinto + +import ( + "context" + "encoding/json" + "fmt" + + "github.com/google/uuid" + "github.com/pingcap/errors" + "github.com/pingcap/tidb/br/pkg/lightning/checkpoints" + "github.com/pingcap/tidb/disttask/framework/handle" + "github.com/pingcap/tidb/disttask/framework/proto" + "github.com/pingcap/tidb/disttask/framework/storage" + "github.com/pingcap/tidb/domain/infosync" + "github.com/pingcap/tidb/executor/importer" + "github.com/pingcap/tidb/parser/mysql" + "github.com/pingcap/tidb/sessionctx" + "github.com/pingcap/tidb/util/dbterror/exeerrors" + "github.com/pingcap/tidb/util/logutil" + "github.com/pingcap/tidb/util/sqlexec" + "go.uber.org/zap" +) + +// DistImporter is a JobImporter for distributed IMPORT INTO. +type DistImporter struct { + *importer.JobImportParam + plan *importer.Plan + stmt string + logger *zap.Logger + // the instance to import data, used for single-node import, nil means import data on all instances. + instance *infosync.ServerInfo + // the files to import, when import from server file, we need to pass those file to the framework. + chunkMap map[int32][]Chunk + sourceFileSize int64 + // only set after submit task + jobID int64 + taskID int64 +} + +// NewDistImporter creates a new DistImporter. +func NewDistImporter(param *importer.JobImportParam, plan *importer.Plan, stmt string, sourceFileSize int64) (*DistImporter, error) { + return &DistImporter{ + JobImportParam: param, + plan: plan, + stmt: stmt, + logger: logutil.BgLogger(), + sourceFileSize: sourceFileSize, + }, nil +} + +// NewDistImporterCurrNode creates a new DistImporter to import data on current node. +func NewDistImporterCurrNode(param *importer.JobImportParam, plan *importer.Plan, stmt string, sourceFileSize int64) (*DistImporter, error) { + serverInfo, err := infosync.GetServerInfo() + if err != nil { + return nil, err + } + return &DistImporter{ + JobImportParam: param, + plan: plan, + stmt: stmt, + logger: logutil.BgLogger(), + instance: serverInfo, + sourceFileSize: sourceFileSize, + }, nil +} + +// NewDistImporterServerFile creates a new DistImporter to import given files on current node. +// we also run import on current node. +// todo: merge all 3 ctor into one. +func NewDistImporterServerFile(param *importer.JobImportParam, plan *importer.Plan, stmt string, ecp map[int32]*checkpoints.EngineCheckpoint, sourceFileSize int64) (*DistImporter, error) { + distImporter, err := NewDistImporterCurrNode(param, plan, stmt, sourceFileSize) + if err != nil { + return nil, err + } + distImporter.chunkMap = toChunkMap(ecp) + return distImporter, nil +} + +// Param implements JobImporter.Param. +func (ti *DistImporter) Param() *importer.JobImportParam { + return ti.JobImportParam +} + +// Import implements JobImporter.Import. +func (*DistImporter) Import() { + // todo: remove it +} + +// ImportTask import task. +func (ti *DistImporter) ImportTask(task *proto.Task) { + ti.logger.Info("start distribute IMPORT INTO") + ti.Group.Go(func() error { + defer close(ti.Done) + // task is run using distribute framework, so we only wait for the task to finish. + return handle.WaitGlobalTask(ti.GroupCtx, task) + }) +} + +// Result implements JobImporter.Result. 
+func (ti *DistImporter) Result() importer.JobImportResult {
+ var result importer.JobImportResult
+ taskMeta, err := getTaskMeta(ti.jobID)
+ if err != nil {
+ result.Msg = err.Error()
+ return result
+ }
+
+ var (
+ numWarnings uint64
+ numRecords uint64
+ numDeletes uint64
+ numSkipped uint64
+ )
+ numRecords = taskMeta.Result.ReadRowCnt
+ // todo: we don't have a strict REPLACE or IGNORE mode in physical mode, so we can't get the numDeletes/numSkipped.
+ // we can have them once duplicate detection is supported.
+ msg := fmt.Sprintf(mysql.MySQLErrName[mysql.ErrLoadInfo].Raw, numRecords, numDeletes, numSkipped, numWarnings)
+ return importer.JobImportResult{
+ Msg: msg,
+ Affected: taskMeta.Result.ReadRowCnt,
+ ColSizeMap: taskMeta.Result.ColSizeMap,
+ }
+}
+
+// Close implements the io.Closer interface.
+func (*DistImporter) Close() error {
+ return nil
+}
+
+// SubmitTask submits a task to the distributed framework.
+func (ti *DistImporter) SubmitTask(ctx context.Context) (int64, *proto.Task, error) {
+ var instances []*infosync.ServerInfo
+ if ti.instance != nil {
+ instances = append(instances, ti.instance)
+ }
+ // we use globalTaskManager to submit the task, since the user might not have the privilege to access system tables.
+ globalTaskManager, err := storage.GetTaskManager()
+ if err != nil {
+ return 0, nil, err
+ }
+
+ var jobID, taskID int64
+ plan := ti.plan
+ if err = globalTaskManager.WithNewTxn(ctx, func(se sessionctx.Context) error {
+ var err2 error
+ exec := se.(sqlexec.SQLExecutor)
+ // If 2 clients try to execute IMPORT INTO concurrently, there's a chance that both of them will pass the check.
+ // We can enforce that ONLY one import job runs by:
+ // - using LOCK TABLES, but it requires enable-table-lock=true, which is not enabled by default.
+ // - adding a key to PD as a distributed lock, but it's a little complex, and we might support job queuing later.
+ // So we only add this simple soft check here and document it.
+ activeJobCnt, err2 := importer.GetActiveJobCnt(ctx, exec)
+ if err2 != nil {
+ return err2
+ }
+ if activeJobCnt > 0 {
+ return exeerrors.ErrLoadDataPreCheckFailed.FastGenByArgs("there's pending or running jobs")
+ }
+ jobID, err2 = importer.CreateJob(ctx, exec, plan.DBName, plan.TableInfo.Name.L, plan.TableInfo.ID,
+ plan.User, plan.Parameters, ti.sourceFileSize)
+ if err2 != nil {
+ return err2
+ }
+ task := TaskMeta{
+ JobID: jobID,
+ Plan: *plan,
+ Stmt: ti.stmt,
+ EligibleInstances: instances,
+ ChunkMap: ti.chunkMap,
+ }
+ taskMeta, err2 := json.Marshal(task)
+ if err2 != nil {
+ return err2
+ }
+ taskID, err2 = globalTaskManager.AddGlobalTaskWithSession(se, TaskKey(jobID), proto.ImportInto,
+ int(plan.ThreadCnt), taskMeta)
+ if err2 != nil {
+ return err2
+ }
+ return nil
+ }); err != nil {
+ return 0, nil, err
+ }
+
+ globalTask, err := globalTaskManager.GetGlobalTaskByID(taskID)
+ if err != nil {
+ return 0, nil, err
+ }
+ if globalTask == nil {
+ return 0, nil, errors.Errorf("cannot find global task with ID %d", taskID)
+ }
+ // update logger with task id.
+ ti.jobID = jobID
+ ti.taskID = taskID
+ ti.logger = ti.logger.With(zap.Int64("task-id", globalTask.ID))
+
+ ti.logger.Info("job submitted to global task queue", zap.Int64("job-id", jobID))
+
+ return jobID, globalTask, nil
+}
+
+func (*DistImporter) taskKey() string {
+ // the task key is meaningless to IMPORT INTO, so we use a random UUID.
+ return fmt.Sprintf("%s/%s", proto.ImportInto, uuid.New().String())
+}
+
+// JobID returns the job id.
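+// It is only set after SubmitTask succeeds; before that it is 0.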
+func (ti *DistImporter) JobID() int64 { + return ti.jobID +} + +func getTaskMeta(jobID int64) (*TaskMeta, error) { + globalTaskManager, err := storage.GetTaskManager() + if err != nil { + return nil, err + } + taskKey := TaskKey(jobID) + globalTask, err := globalTaskManager.GetGlobalTaskByKey(taskKey) + if err != nil { + return nil, err + } + if globalTask == nil { + return nil, errors.Errorf("cannot find global task with key %s", taskKey) + } + var taskMeta TaskMeta + if err := json.Unmarshal(globalTask.Meta, &taskMeta); err != nil { + return nil, err + } + return &taskMeta, nil +} + +// GetTaskImportedRows gets the number of imported rows of a job. +// Note: for finished job, we can get the number of imported rows from task meta. +func GetTaskImportedRows(jobID int64) (uint64, error) { + globalTaskManager, err := storage.GetTaskManager() + if err != nil { + return 0, err + } + taskKey := TaskKey(jobID) + globalTask, err := globalTaskManager.GetGlobalTaskByKey(taskKey) + if err != nil { + return 0, err + } + if globalTask == nil { + return 0, errors.Errorf("cannot find global task with key %s", taskKey) + } + subtasks, err := globalTaskManager.GetSubtasksByStep(globalTask.ID, StepImport) + if err != nil { + return 0, err + } + var importedRows uint64 + for _, subtask := range subtasks { + var subtaskMeta ImportStepMeta + if err2 := json.Unmarshal(subtask.Meta, &subtaskMeta); err2 != nil { + return 0, err2 + } + importedRows += subtaskMeta.Result.LoadedRowCnt + } + return importedRows, nil +} + +// TaskKey returns the task key for a job. +func TaskKey(jobID int64) string { + return fmt.Sprintf("%s/%d", proto.ImportInto, jobID) +} diff --git a/disttask/importinto/subtask_executor.go b/disttask/importinto/subtask_executor.go new file mode 100644 index 0000000000000..be6de9a75d0c0 --- /dev/null +++ b/disttask/importinto/subtask_executor.go @@ -0,0 +1,240 @@ +// Copyright 2023 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package importinto + +import ( + "context" + "strconv" + "time" + + "github.com/pingcap/errors" + "github.com/pingcap/failpoint" + "github.com/pingcap/tidb/br/pkg/lightning/backend/local" + "github.com/pingcap/tidb/br/pkg/lightning/common" + "github.com/pingcap/tidb/br/pkg/lightning/config" + verify "github.com/pingcap/tidb/br/pkg/lightning/verification" + "github.com/pingcap/tidb/disttask/framework/proto" + "github.com/pingcap/tidb/disttask/framework/scheduler" + "github.com/pingcap/tidb/disttask/framework/storage" + "github.com/pingcap/tidb/executor/importer" + "github.com/pingcap/tidb/kv" + "github.com/pingcap/tidb/sessionctx" + "github.com/pingcap/tidb/sessionctx/variable" + "github.com/pingcap/tidb/util/logutil" + "github.com/pingcap/tidb/util/mathutil" + "github.com/tikv/client-go/v2/util" + "go.uber.org/zap" +) + +// TestSyncChan is used to test. +var TestSyncChan = make(chan struct{}) + +// ImportMinimalTaskExecutor is a minimal task executor for IMPORT INTO. 
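+// It encodes and sorts one chunk of source data, then merges the chunk checksum into the shared subtask checksum.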
+type ImportMinimalTaskExecutor struct { + mTtask *importStepMinimalTask +} + +// Run implements the SubtaskExecutor.Run interface. +func (e *ImportMinimalTaskExecutor) Run(ctx context.Context) error { + logger := logutil.BgLogger().With(zap.String("type", proto.ImportInto), zap.Int64("table-id", e.mTtask.Plan.TableInfo.ID)) + logger.Info("run minimal task") + failpoint.Inject("waitBeforeSortChunk", func() { + time.Sleep(3 * time.Second) + }) + failpoint.Inject("errorWhenSortChunk", func() { + failpoint.Return(errors.New("occur an error when sort chunk")) + }) + failpoint.Inject("syncBeforeSortChunk", func() { + TestSyncChan <- struct{}{} + <-TestSyncChan + }) + chunkCheckpoint := toChunkCheckpoint(e.mTtask.Chunk) + sharedVars := e.mTtask.SharedVars + if err := importer.ProcessChunk(ctx, &chunkCheckpoint, sharedVars.TableImporter, sharedVars.DataEngine, sharedVars.IndexEngine, sharedVars.Progress, logger); err != nil { + return err + } + + sharedVars.mu.Lock() + defer sharedVars.mu.Unlock() + sharedVars.Checksum.Add(&chunkCheckpoint.Checksum) + return nil +} + +type postProcessMinimalTaskExecutor struct { + mTask *postProcessStepMinimalTask +} + +func (e *postProcessMinimalTaskExecutor) Run(ctx context.Context) error { + mTask := e.mTask + failpoint.Inject("waitBeforePostProcess", func() { + time.Sleep(5 * time.Second) + }) + return postProcess(ctx, mTask.taskMeta, &mTask.meta, mTask.logger) +} + +// postProcess does the post-processing for the task. +func postProcess(ctx context.Context, taskMeta *TaskMeta, subtaskMeta *PostProcessStepMeta, logger *zap.Logger) (err error) { + failpoint.Inject("syncBeforePostProcess", func() { + TestSyncChan <- struct{}{} + <-TestSyncChan + }) + + logger.Info("post process") + + // TODO: create table indexes depends on the option. + // create table indexes even if the post process is failed. 
+ // defer func() { + // err2 := createTableIndexes(ctx, globalTaskManager, taskMeta, logger) + // err = multierr.Append(err, err2) + // }() + + return verifyChecksum(ctx, taskMeta, subtaskMeta, logger) +} + +func verifyChecksum(ctx context.Context, taskMeta *TaskMeta, subtaskMeta *PostProcessStepMeta, logger *zap.Logger) error { + if taskMeta.Plan.Checksum == config.OpLevelOff { + return nil + } + localChecksum := verify.MakeKVChecksum(subtaskMeta.Checksum.Size, subtaskMeta.Checksum.KVs, subtaskMeta.Checksum.Sum) + logger.Info("local checksum", zap.Object("checksum", &localChecksum)) + + failpoint.Inject("waitCtxDone", func() { + <-ctx.Done() + }) + + globalTaskManager, err := storage.GetTaskManager() + if err != nil { + return err + } + remoteChecksum, err := checksumTable(ctx, globalTaskManager, taskMeta, logger) + if err != nil { + return err + } + if !remoteChecksum.IsEqual(&localChecksum) { + err2 := common.ErrChecksumMismatch.GenWithStackByArgs( + remoteChecksum.Checksum, localChecksum.Sum(), + remoteChecksum.TotalKVs, localChecksum.SumKVS(), + remoteChecksum.TotalBytes, localChecksum.SumSize(), + ) + if taskMeta.Plan.Checksum == config.OpLevelOptional { + logger.Warn("verify checksum failed, but checksum is optional, will skip it", zap.Error(err2)) + err2 = nil + } + return err2 + } + logger.Info("checksum pass", zap.Object("local", &localChecksum)) + return nil +} + +func checksumTable(ctx context.Context, executor storage.SessionExecutor, taskMeta *TaskMeta, logger *zap.Logger) (*local.RemoteChecksum, error) { + var ( + tableName = common.UniqueTable(taskMeta.Plan.DBName, taskMeta.Plan.TableInfo.Name.L) + sql = "ADMIN CHECKSUM TABLE " + tableName + maxErrorRetryCount = 3 + distSQLScanConcurrencyFactor = 1 + remoteChecksum *local.RemoteChecksum + txnErr error + ) + + ctx = util.WithInternalSourceType(ctx, kv.InternalImportInto) + for i := 0; i < maxErrorRetryCount; i++ { + txnErr = executor.WithNewTxn(ctx, func(se sessionctx.Context) error { + // increase backoff weight + if err := setBackoffWeight(se, taskMeta, logger); err != nil { + logger.Warn("set tidb_backoff_weight failed", zap.Error(err)) + } + + distSQLScanConcurrency := se.GetSessionVars().DistSQLScanConcurrency() + se.GetSessionVars().SetDistSQLScanConcurrency(mathutil.Max(distSQLScanConcurrency/distSQLScanConcurrencyFactor, local.MinDistSQLScanConcurrency)) + defer func() { + se.GetSessionVars().SetDistSQLScanConcurrency(distSQLScanConcurrency) + }() + + rs, err := storage.ExecSQL(ctx, se, sql) + if err != nil { + return err + } + if len(rs) < 1 { + return errors.New("empty checksum result") + } + + failpoint.Inject("errWhenChecksum", func() { + if i == 0 { + failpoint.Return(errors.New("occur an error when checksum, coprocessor task terminated due to exceeding the deadline")) + } + }) + + // ADMIN CHECKSUM TABLE .
example. + // mysql> admin checksum table test.t; + // +---------+------------+---------------------+-----------+-------------+ + // | Db_name | Table_name | Checksum_crc64_xor | Total_kvs | Total_bytes | + // +---------+------------+---------------------+-----------+-------------+ + // | test | t | 8520875019404689597 | 7296873 | 357601387 | + // +---------+------------+------------- + remoteChecksum = &local.RemoteChecksum{ + Schema: rs[0].GetString(0), + Table: rs[0].GetString(1), + Checksum: rs[0].GetUint64(2), + TotalKVs: rs[0].GetUint64(3), + TotalBytes: rs[0].GetUint64(4), + } + return nil + }) + if !common.IsRetryableError(txnErr) { + break + } + distSQLScanConcurrencyFactor *= 2 + logger.Warn("retry checksum table", zap.Int("retry count", i+1), zap.Error(txnErr)) + } + return remoteChecksum, txnErr +} + +// TestChecksumTable is used to test checksum table in unit test. +func TestChecksumTable(ctx context.Context, executor storage.SessionExecutor, taskMeta *TaskMeta, logger *zap.Logger) (*local.RemoteChecksum, error) { + return checksumTable(ctx, executor, taskMeta, logger) +} + +func setBackoffWeight(se sessionctx.Context, taskMeta *TaskMeta, logger *zap.Logger) error { + backoffWeight := local.DefaultBackoffWeight + if val, ok := taskMeta.Plan.ImportantSysVars[variable.TiDBBackOffWeight]; ok { + if weight, err := strconv.Atoi(val); err == nil && weight > backoffWeight { + backoffWeight = weight + } + } + logger.Info("set backoff weight", zap.Int("weight", backoffWeight)) + return se.GetSessionVars().SetSystemVar(variable.TiDBBackOffWeight, strconv.Itoa(backoffWeight)) +} + +func init() { + scheduler.RegisterSubtaskExectorConstructor(proto.ImportInto, StepImport, + // The order of the subtask executors is the same as the order of the subtasks. + func(minimalTask proto.MinimalTask, step int64) (scheduler.SubtaskExecutor, error) { + task, ok := minimalTask.(*importStepMinimalTask) + if !ok { + return nil, errors.Errorf("invalid task type %T", minimalTask) + } + return &ImportMinimalTaskExecutor{mTtask: task}, nil + }, + ) + scheduler.RegisterSubtaskExectorConstructor(proto.ImportInto, StepPostProcess, + func(minimalTask proto.MinimalTask, step int64) (scheduler.SubtaskExecutor, error) { + mTask, ok := minimalTask.(*postProcessStepMinimalTask) + if !ok { + return nil, errors.Errorf("invalid task type %T", minimalTask) + } + return &postProcessMinimalTaskExecutor{mTask: mTask}, nil + }, + ) +} diff --git a/disttask/loaddata/subtask_executor_test.go b/disttask/loaddata/subtask_executor_test.go new file mode 100644 index 0000000000000..4596ffc795aa2 --- /dev/null +++ b/disttask/loaddata/subtask_executor_test.go @@ -0,0 +1,73 @@ +// Copyright 2023 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +package importinto_test + +import ( + "context" + "testing" + "time" + + "github.com/ngaut/pools" + "github.com/pingcap/failpoint" + verify "github.com/pingcap/tidb/br/pkg/lightning/verification" + "github.com/pingcap/tidb/disttask/framework/storage" + "github.com/pingcap/tidb/disttask/importinto" + "github.com/pingcap/tidb/executor/importer" + "github.com/pingcap/tidb/parser/model" + "github.com/pingcap/tidb/testkit" + "github.com/pingcap/tidb/util/logutil" + "github.com/stretchr/testify/require" + "github.com/tikv/client-go/v2/util" +) + +func TestChecksumTable(t *testing.T) { + ctx := context.Background() + store := testkit.CreateMockStore(t) + gtk := testkit.NewTestKit(t, store) + pool := pools.NewResourcePool(func() (pools.Resource, error) { + return gtk.Session(), nil + }, 1, 1, time.Second) + defer pool.Close() + mgr := storage.NewTaskManager(util.WithInternalSourceType(ctx, "taskManager"), pool) + + taskMeta := &importinto.TaskMeta{ + Plan: importer.Plan{ + DBName: "db", + TableInfo: &model.TableInfo{ + Name: model.NewCIStr("tb"), + }, + }, + } + // fake result + localChecksum := verify.MakeKVChecksum(1, 1, 1) + gtk.MustExec("create database db") + gtk.MustExec("create table db.tb(id int)") + gtk.MustExec("insert into db.tb values(1)") + remoteChecksum, err := importinto.TestChecksumTable(ctx, mgr, taskMeta, logutil.BgLogger()) + require.NoError(t, err) + require.True(t, remoteChecksum.IsEqual(&localChecksum)) + // again + remoteChecksum, err = importinto.TestChecksumTable(ctx, mgr, taskMeta, logutil.BgLogger()) + require.NoError(t, err) + require.True(t, remoteChecksum.IsEqual(&localChecksum)) + + _ = failpoint.Enable("github.com/pingcap/tidb/disttask/importinto/errWhenChecksum", `return(true)`) + defer func() { + _ = failpoint.Disable("github.com/pingcap/tidb/disttask/importinto/errWhenChecksum") + }() + remoteChecksum, err = importinto.TestChecksumTable(ctx, mgr, taskMeta, logutil.BgLogger()) + require.NoError(t, err) + require.True(t, remoteChecksum.IsEqual(&localChecksum)) +} diff --git a/executor/import_into.go b/executor/import_into.go new file mode 100644 index 0000000000000..92f16fb13f611 --- /dev/null +++ b/executor/import_into.go @@ -0,0 +1,302 @@ +// Copyright 2023 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +package executor + +import ( + "context" + "sync/atomic" + + "github.com/pingcap/errors" + "github.com/pingcap/failpoint" + "github.com/pingcap/tidb/br/pkg/storage" + "github.com/pingcap/tidb/disttask/framework/proto" + fstorage "github.com/pingcap/tidb/disttask/framework/storage" + "github.com/pingcap/tidb/disttask/importinto" + "github.com/pingcap/tidb/executor/asyncloaddata" + "github.com/pingcap/tidb/executor/importer" + "github.com/pingcap/tidb/kv" + "github.com/pingcap/tidb/parser/ast" + "github.com/pingcap/tidb/parser/mysql" + plannercore "github.com/pingcap/tidb/planner/core" + "github.com/pingcap/tidb/privilege" + "github.com/pingcap/tidb/sessionctx" + "github.com/pingcap/tidb/sessionctx/variable" + "github.com/pingcap/tidb/sessiontxn" + "github.com/pingcap/tidb/table" + "github.com/pingcap/tidb/util/chunk" + "github.com/pingcap/tidb/util/dbterror/exeerrors" + "github.com/pingcap/tidb/util/logutil" + "github.com/pingcap/tidb/util/sqlexec" + "go.uber.org/zap" + "golang.org/x/sync/errgroup" +) + +var ( + // TestDetachedTaskFinished is a flag for test. + TestDetachedTaskFinished atomic.Bool + // TestCancelFunc for test. + TestCancelFunc context.CancelFunc +) + +const unknownImportedRowCount = -1 + +// ImportIntoExec represents a IMPORT INTO executor. +type ImportIntoExec struct { + baseExecutor + userSctx sessionctx.Context + importPlan *importer.Plan + controller *importer.LoadDataController + stmt string + + dataFilled bool +} + +var ( + _ Executor = (*ImportIntoExec)(nil) +) + +func newImportIntoExec(b baseExecutor, userSctx sessionctx.Context, plan *plannercore.ImportInto, tbl table.Table) ( + *ImportIntoExec, error) { + importPlan, err := importer.NewImportPlan(userSctx, plan, tbl) + if err != nil { + return nil, err + } + astArgs := importer.ASTArgsFromImportPlan(plan) + controller, err := importer.NewLoadDataController(importPlan, tbl, astArgs) + if err != nil { + return nil, err + } + return &ImportIntoExec{ + baseExecutor: b, + userSctx: userSctx, + importPlan: importPlan, + controller: controller, + stmt: plan.Stmt, + }, nil +} + +// Next implements the Executor Next interface. +func (e *ImportIntoExec) Next(ctx context.Context, req *chunk.Chunk) (err error) { + req.GrowAndReset(e.maxChunkSize) + ctx = kv.WithInternalSourceType(ctx, kv.InternalImportInto) + if e.dataFilled { + // need to return an empty req to indicate all results have been written + return nil + } + if err2 := e.controller.InitDataFiles(ctx); err2 != nil { + return err2 + } + + // must use a new session to pre-check, else the stmt in show processlist will be changed. + newSCtx, err2 := CreateSession(e.userSctx) + if err2 != nil { + return err2 + } + defer CloseSession(newSCtx) + sqlExec := newSCtx.(sqlexec.SQLExecutor) + if err2 = e.controller.CheckRequirements(ctx, sqlExec); err2 != nil { + return err2 + } + + failpoint.Inject("cancellableCtx", func() { + // KILL is not implemented in testkit, so we use a fail-point to simulate it. + newCtx, cancel := context.WithCancel(ctx) + ctx = newCtx + TestCancelFunc = cancel + }) + // todo: we don't need Job now, remove it later. 
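+ // for a detached import the job must outlive the current statement, so the group is derived from a background context.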
+ parentCtx := ctx + if e.controller.Detached { + parentCtx = context.Background() + } + group, groupCtx := errgroup.WithContext(parentCtx) + param := &importer.JobImportParam{ + Job: &asyncloaddata.Job{}, + Group: group, + GroupCtx: groupCtx, + Done: make(chan struct{}), + Progress: asyncloaddata.NewProgress(false), + } + distImporter, err := e.getJobImporter(ctx, param) + if err != nil { + return err + } + defer func() { + _ = distImporter.Close() + }() + param.Progress.SourceFileSize = e.controller.TotalFileSize + jobID, task, err := distImporter.SubmitTask(ctx) + if err != nil { + return err + } + + if e.controller.Detached { + ctx := kv.WithInternalSourceType(context.Background(), kv.InternalImportInto) + se, err := CreateSession(e.userSctx) + if err != nil { + return err + } + go func() { + defer CloseSession(se) + // error is stored in system table, so we can ignore it here + //nolint: errcheck + _ = e.doImport(ctx, se, distImporter, task) + failpoint.Inject("testDetachedTaskFinished", func() { + TestDetachedTaskFinished.Store(true) + }) + }() + return e.fillJobInfo(ctx, jobID, req) + } + if err = e.doImport(ctx, e.userSctx, distImporter, task); err != nil { + return err + } + return e.fillJobInfo(ctx, jobID, req) +} + +func (e *ImportIntoExec) fillJobInfo(ctx context.Context, jobID int64, req *chunk.Chunk) error { + e.dataFilled = true + // we use globalTaskManager to get job, user might not have the privilege to system tables. + globalTaskManager, err := fstorage.GetTaskManager() + if err != nil { + return err + } + var info *importer.JobInfo + if err = globalTaskManager.WithNewSession(func(se sessionctx.Context) error { + sqlExec := se.(sqlexec.SQLExecutor) + var err2 error + info, err2 = importer.GetJob(ctx, sqlExec, jobID, e.ctx.GetSessionVars().User.String(), false) + return err2 + }); err != nil { + return err + } + fillOneImportJobInfo(info, req, unknownImportedRowCount) + return nil +} + +func (e *ImportIntoExec) getJobImporter(ctx context.Context, param *importer.JobImportParam) (*importinto.DistImporter, error) { + importFromServer, err := storage.IsLocalPath(e.controller.Path) + if err != nil { + // since we have checked this during creating controller, this should not happen. + return nil, exeerrors.ErrLoadDataInvalidURI.FastGenByArgs(err.Error()) + } + logutil.Logger(ctx).Info("get job importer", zap.Stringer("param", e.controller.Parameters), + zap.Bool("dist-task-enabled", variable.EnableDistTask.Load())) + if importFromServer { + ecp, err2 := e.controller.PopulateChunks(ctx) + if err2 != nil { + return nil, err2 + } + return importinto.NewDistImporterServerFile(param, e.importPlan, e.stmt, ecp, e.controller.TotalFileSize) + } + // if tidb_enable_dist_task=true, we import distributively, otherwise we import on current node. + if variable.EnableDistTask.Load() { + return importinto.NewDistImporter(param, e.importPlan, e.stmt, e.controller.TotalFileSize) + } + return importinto.NewDistImporterCurrNode(param, e.importPlan, e.stmt, e.controller.TotalFileSize) +} + +func (e *ImportIntoExec) doImport(ctx context.Context, se sessionctx.Context, distImporter *importinto.DistImporter, task *proto.Task) error { + distImporter.ImportTask(task) + group := distImporter.Param().Group + err := group.Wait() + // when user KILL the connection, the ctx will be canceled, we need to cancel the import job. 
+ if errors.Cause(err) == context.Canceled {
+ globalTaskManager, err2 := fstorage.GetTaskManager()
+ if err2 != nil {
+ return err2
+ }
+ // use a background context, since ctx is already canceled.
+ return cancelImportJob(context.Background(), globalTaskManager, distImporter.JobID())
+ }
+ if err2 := flushStats(ctx, se, e.importPlan.TableInfo.ID, distImporter.Result()); err2 != nil {
+ logutil.Logger(ctx).Error("flush stats failed", zap.Error(err2))
+ }
+ return err
+}
+
+// ImportIntoActionExec represents an IMPORT INTO action executor.
+type ImportIntoActionExec struct {
+ baseExecutor
+ tp ast.ImportIntoActionTp
+ jobID int64
+}
+
+var (
+ _ Executor = (*ImportIntoActionExec)(nil)
+)
+
+// Next implements the Executor Next interface.
+func (e *ImportIntoActionExec) Next(ctx context.Context, _ *chunk.Chunk) error {
+ ctx = kv.WithInternalSourceType(ctx, kv.InternalImportInto)
+
+ var hasSuperPriv bool
+ if pm := privilege.GetPrivilegeManager(e.ctx); pm != nil {
+ hasSuperPriv = pm.RequestVerification(e.ctx.GetSessionVars().ActiveRoles, "", "", "", mysql.SuperPriv)
+ }
+ // we use the sessionCtx from GetTaskManager, since the user ctx might not have enough privileges.
+ globalTaskManager, err := fstorage.GetTaskManager()
+ if err != nil {
+ return err
+ }
+ if err = e.checkPrivilegeAndStatus(ctx, globalTaskManager, hasSuperPriv); err != nil {
+ return err
+ }
+
+ logutil.Logger(ctx).Info("import into action", zap.Int64("jobID", e.jobID), zap.Any("action", e.tp))
+ return cancelImportJob(ctx, globalTaskManager, e.jobID)
+}
+
+func (e *ImportIntoActionExec) checkPrivilegeAndStatus(ctx context.Context, manager *fstorage.TaskManager, hasSuperPriv bool) error {
+ var info *importer.JobInfo
+ if err := manager.WithNewSession(func(se sessionctx.Context) error {
+ exec := se.(sqlexec.SQLExecutor)
+ var err2 error
+ info, err2 = importer.GetJob(ctx, exec, e.jobID, e.ctx.GetSessionVars().User.String(), hasSuperPriv)
+ return err2
+ }); err != nil {
+ return err
+ }
+ if !info.CanCancel() {
+ return exeerrors.ErrLoadDataInvalidOperation.FastGenByArgs("CANCEL")
+ }
+ return nil
+}
+
+// flushStats flushes the stats of the table.
+func flushStats(ctx context.Context, se sessionctx.Context, tableID int64, result importer.JobImportResult) error {
+ if err := sessiontxn.NewTxn(ctx, se); err != nil {
+ return err
+ }
+ sessionVars := se.GetSessionVars()
+ sessionVars.TxnCtxMu.Lock()
+ defer sessionVars.TxnCtxMu.Unlock()
+ sessionVars.TxnCtx.UpdateDeltaForTable(tableID, int64(result.Affected), int64(result.Affected), result.ColSizeMap)
+ se.StmtCommit(ctx)
+ return se.CommitTxn(ctx)
+}
+
+func cancelImportJob(ctx context.Context, manager *fstorage.TaskManager, jobID int64) error {
+ // todo: cancel is an async operation, we don't wait here for now, maybe add a wait syntax later.
+ // todo: after CANCEL, the user sees the job status as Canceled immediately, but the job might still be running,
+ // and the state of the framework task might become finished, since the framework doesn't enforce the state-change DAG when updating the task.
+ // todo: add a CANCELLING status?
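+ // mark the import job as cancelled and cancel the corresponding global task in the same transaction.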
+ return manager.WithNewTxn(ctx, func(se sessionctx.Context) error { + exec := se.(sqlexec.SQLExecutor) + if err2 := importer.CancelJob(ctx, exec, jobID); err2 != nil { + return err2 + } + return manager.CancelGlobalTaskByKeySession(se, importinto.TaskKey(jobID)) + }) +} diff --git a/executor/importer/BUILD.bazel b/executor/importer/BUILD.bazel index 52444cf42442a..57ab9d7da9a10 100644 --- a/executor/importer/BUILD.bazel +++ b/executor/importer/BUILD.bazel @@ -58,6 +58,11 @@ go_library( "@com_github_pingcap_log//:log", "@com_github_tikv_client_go_v2//config", "@com_github_tikv_client_go_v2//tikv", +<<<<<<< HEAD +======= + "@com_github_tikv_client_go_v2//util", + "@org_golang_x_exp//slices", +>>>>>>> 89bf7432279 (importinto/lightning: do remote checksum via sql (#44803)) "@org_golang_x_sync//errgroup", "@org_uber_go_multierr//:multierr", "@org_uber_go_zap//:zap", diff --git a/executor/importer/table_import.go b/executor/importer/table_import.go index 4c6c8797a52d7..780f1f614779c 100644 --- a/executor/importer/table_import.go +++ b/executor/importer/table_import.go @@ -26,7 +26,6 @@ import ( "time" "github.com/pingcap/errors" - "github.com/pingcap/failpoint" "github.com/pingcap/tidb/br/pkg/lightning/backend" "github.com/pingcap/tidb/br/pkg/lightning/backend/encode" "github.com/pingcap/tidb/br/pkg/lightning/backend/kv" @@ -36,7 +35,6 @@ import ( "github.com/pingcap/tidb/br/pkg/lightning/config" "github.com/pingcap/tidb/br/pkg/lightning/log" "github.com/pingcap/tidb/br/pkg/lightning/mydump" - verify "github.com/pingcap/tidb/br/pkg/lightning/verification" "github.com/pingcap/tidb/br/pkg/storage" tidb "github.com/pingcap/tidb/config" tidbkv "github.com/pingcap/tidb/kv" @@ -44,6 +42,10 @@ import ( "github.com/pingcap/tidb/table" "github.com/pingcap/tidb/table/tables" "github.com/pingcap/tidb/util" +<<<<<<< HEAD +======= + "github.com/pingcap/tidb/util/syncutil" +>>>>>>> 89bf7432279 (importinto/lightning: do remote checksum via sql (#44803)) "go.uber.org/multierr" "go.uber.org/zap" ) @@ -278,6 +280,7 @@ func (ti *TableImporter) getKVEncoder(chunk *checkpoints.ChunkCheckpoint) (kvEnc return newTableKVEncoder(cfg, ti.ColumnAssignments, ti.ColumnsAndUserVars, ti.FieldMappings, ti.InsertColumns) } +<<<<<<< HEAD func (ti *TableImporter) importTable(ctx context.Context) error { // todo: pause GC if we need duplicate detection // todo: register task to pd? @@ -331,6 +334,8 @@ func (ti *TableImporter) checksumTable(ctx context.Context) error { return nil } +======= +>>>>>>> 89bf7432279 (importinto/lightning: do remote checksum via sql (#44803)) // PopulateChunks populates chunks from table regions. // in dist framework, this should be done in the tidb node which is responsible for splitting job into subtasks // then table-importer handles data belongs to the subtask. diff --git a/tests/realtikvtest/importintotest/job_test.go b/tests/realtikvtest/importintotest/job_test.go new file mode 100644 index 0000000000000..82397c946fa8e --- /dev/null +++ b/tests/realtikvtest/importintotest/job_test.go @@ -0,0 +1,635 @@ +// Copyright 2023 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. + +package importintotest + +import ( + "context" + "fmt" + "net/url" + "strconv" + "strings" + "sync" + "time" + + "github.com/docker/go-units" + "github.com/fsouza/fake-gcs-server/fakestorage" + "github.com/pingcap/errors" + "github.com/pingcap/failpoint" + "github.com/pingcap/tidb/br/pkg/lightning/config" + "github.com/pingcap/tidb/br/pkg/utils" + "github.com/pingcap/tidb/disttask/framework/proto" + "github.com/pingcap/tidb/disttask/framework/scheduler" + "github.com/pingcap/tidb/disttask/framework/storage" + "github.com/pingcap/tidb/disttask/importinto" + "github.com/pingcap/tidb/executor" + "github.com/pingcap/tidb/executor/importer" + "github.com/pingcap/tidb/parser/auth" + "github.com/pingcap/tidb/planner/core" + "github.com/pingcap/tidb/session" + "github.com/pingcap/tidb/testkit" + "github.com/pingcap/tidb/util/dbterror/exeerrors" +) + +func (s *mockGCSSuite) compareJobInfoWithoutTime(jobInfo *importer.JobInfo, row []interface{}) { + s.Equal(strconv.Itoa(int(jobInfo.ID)), row[0]) + + urlExpected, err := url.Parse(jobInfo.Parameters.FileLocation) + s.NoError(err) + urlGot, err := url.Parse(fmt.Sprintf("%v", row[1])) + s.NoError(err) + // order of query parameters might change + s.Equal(urlExpected.Query(), urlGot.Query()) + urlExpected.RawQuery, urlGot.RawQuery = "", "" + s.Equal(urlExpected.String(), urlGot.String()) + + s.Equal(utils.EncloseDBAndTable(jobInfo.TableSchema, jobInfo.TableName), row[2]) + s.Equal(strconv.Itoa(int(jobInfo.TableID)), row[3]) + s.Equal(jobInfo.Step, row[4]) + s.Equal(jobInfo.Status, row[5]) + s.Equal(units.HumanSize(float64(jobInfo.SourceFileSize)), row[6]) + if jobInfo.Summary == nil { + s.Equal("", row[7].(string)) + } else { + s.Equal(strconv.Itoa(int(jobInfo.Summary.ImportedRows)), row[7]) + } + s.Regexp(jobInfo.ErrorMessage, row[8]) + s.Equal(jobInfo.CreatedBy, row[12]) +} + +func (s *mockGCSSuite) TestShowJob() { + s.tk.MustExec("delete from mysql.tidb_import_jobs") + s.prepareAndUseDB("test_show_job") + s.tk.MustExec("CREATE TABLE t1 (i INT PRIMARY KEY);") + s.tk.MustExec("CREATE TABLE t2 (i INT PRIMARY KEY);") + s.tk.MustExec("CREATE TABLE t3 (i INT PRIMARY KEY);") + s.server.CreateObject(fakestorage.Object{ + ObjectAttrs: fakestorage.ObjectAttrs{BucketName: "test-show-job", Name: "t.csv"}, + Content: []byte("1\n2"), + }) + s.T().Cleanup(func() { + _ = s.tk.Session().Auth(&auth.UserIdentity{Username: "root", Hostname: "localhost"}, nil, nil, nil) + }) + // create 2 user which don't have system table privileges + s.tk.MustExec(`DROP USER IF EXISTS 'test_show_job1'@'localhost';`) + s.tk.MustExec(`CREATE USER 'test_show_job1'@'localhost';`) + s.tk.MustExec(`GRANT SELECT,UPDATE,INSERT,DELETE,ALTER on test_show_job.* to 'test_show_job1'@'localhost'`) + s.tk.MustExec(`DROP USER IF EXISTS 'test_show_job2'@'localhost';`) + s.tk.MustExec(`CREATE USER 'test_show_job2'@'localhost';`) + s.tk.MustExec(`GRANT SELECT,UPDATE,INSERT,DELETE,ALTER on test_show_job.* to 'test_show_job2'@'localhost'`) + do, err := session.GetDomain(s.store) + s.NoError(err) + tableID1 := do.MustGetTableID(s.T(), "test_show_job", "t1") + tableID2 := do.MustGetTableID(s.T(), "test_show_job", "t2") + tableID3 := do.MustGetTableID(s.T(), "test_show_job", "t3") + + // show non-exists job + err = s.tk.QueryToErr("show import job 9999999999") + s.ErrorIs(err, exeerrors.ErrLoadDataJobNotFound) + + // test show job by id using test_show_job1 + 
s.enableFailpoint("github.com/pingcap/tidb/executor/importer/setLastImportJobID", `return(true)`) + s.enableFailpoint("github.com/pingcap/tidb/disttask/framework/storage/testSetLastTaskID", "return(true)") + s.enableFailpoint("github.com/pingcap/tidb/parser/ast/forceRedactURL", "return(true)") + s.NoError(s.tk.Session().Auth(&auth.UserIdentity{Username: "test_show_job1", Hostname: "localhost"}, nil, nil, nil)) + result1 := s.tk.MustQuery(fmt.Sprintf(`import into t1 FROM 'gs://test-show-job/t.csv?access-key=aaaaaa&secret-access-key=bbbbbb&endpoint=%s'`, + gcsEndpoint)).Rows() + s.Len(result1, 1) + s.tk.MustQuery("select * from t1").Check(testkit.Rows("1", "2")) + rows := s.tk.MustQuery(fmt.Sprintf("show import job %d", importer.TestLastImportJobID.Load())).Rows() + s.Len(rows, 1) + s.Equal(result1, rows) + jobInfo := &importer.JobInfo{ + ID: importer.TestLastImportJobID.Load(), + TableSchema: "test_show_job", + TableName: "t1", + TableID: tableID1, + CreatedBy: "test_show_job1@localhost", + Parameters: importer.ImportParameters{ + FileLocation: fmt.Sprintf(`gs://test-show-job/t.csv?access-key=xxxxxx&secret-access-key=xxxxxx&endpoint=%s`, gcsEndpoint), + Format: importer.DataFormatCSV, + }, + SourceFileSize: 3, + Status: "finished", + Step: "", + Summary: &importer.JobSummary{ + ImportedRows: 2, + }, + ErrorMessage: "", + } + s.compareJobInfoWithoutTime(jobInfo, rows[0]) + + // test show job by id using test_show_job2 + s.NoError(s.tk.Session().Auth(&auth.UserIdentity{Username: "test_show_job2", Hostname: "localhost"}, nil, nil, nil)) + result2 := s.tk.MustQuery(fmt.Sprintf(`import into t2 FROM 'gs://test-show-job/t.csv?endpoint=%s'`, gcsEndpoint)).Rows() + s.tk.MustQuery("select * from t2").Check(testkit.Rows("1", "2")) + rows = s.tk.MustQuery(fmt.Sprintf("show import job %d", importer.TestLastImportJobID.Load())).Rows() + s.Len(rows, 1) + s.Equal(result2, rows) + jobInfo.ID = importer.TestLastImportJobID.Load() + jobInfo.TableName = "t2" + jobInfo.TableID = tableID2 + jobInfo.CreatedBy = "test_show_job2@localhost" + jobInfo.Parameters.FileLocation = fmt.Sprintf(`gs://test-show-job/t.csv?endpoint=%s`, gcsEndpoint) + s.compareJobInfoWithoutTime(jobInfo, rows[0]) + rows = s.tk.MustQuery("show import jobs").Rows() + s.Len(rows, 1) + s.Equal(result2, rows) + + // show import jobs with root + checkJobsMatch := func(rows [][]interface{}) { + s.GreaterOrEqual(len(rows), 2) // other cases may create import jobs + var matched int + for _, r := range rows { + if r[0] == result1[0][0] { + s.Equal(result1[0], r) + matched++ + } + if r[0] == result2[0][0] { + s.Equal(result2[0], r) + matched++ + } + } + s.Equal(2, matched) + } + s.NoError(s.tk.Session().Auth(&auth.UserIdentity{Username: "root", Hostname: "localhost"}, nil, nil, nil)) + rows = s.tk.MustQuery("show import jobs").Rows() + checkJobsMatch(rows) + // show import job by id with root + rows = s.tk.MustQuery(fmt.Sprintf("show import job %d", importer.TestLastImportJobID.Load())).Rows() + s.Len(rows, 1) + s.Equal(result2, rows) + jobInfo.ID = importer.TestLastImportJobID.Load() + jobInfo.TableName = "t2" + jobInfo.TableID = tableID2 + jobInfo.CreatedBy = "test_show_job2@localhost" + jobInfo.Parameters.FileLocation = fmt.Sprintf(`gs://test-show-job/t.csv?endpoint=%s`, gcsEndpoint) + s.compareJobInfoWithoutTime(jobInfo, rows[0]) + + // grant SUPER to test_show_job2, now it can see all jobs + s.tk.MustExec(`GRANT SUPER on *.* to 'test_show_job2'@'localhost'`) + s.NoError(s.tk.Session().Auth(&auth.UserIdentity{Username: "test_show_job2", Hostname: 
"localhost"}, nil, nil, nil)) + rows = s.tk.MustQuery("show import jobs").Rows() + checkJobsMatch(rows) + + // show running jobs with 2 subtasks + s.enableFailpoint("github.com/pingcap/tidb/disttask/framework/scheduler/syncAfterSubtaskFinish", `return(true)`) + s.server.CreateObject(fakestorage.Object{ + ObjectAttrs: fakestorage.ObjectAttrs{BucketName: "test-show-job", Name: "t2.csv"}, + Content: []byte("3\n4"), + }) + backup4 := config.DefaultBatchSize + config.DefaultBatchSize = 1 + s.T().Cleanup(func() { + config.DefaultBatchSize = backup4 + }) + wg := sync.WaitGroup{} + wg.Add(1) + go func() { + defer wg.Done() + // wait first subtask finish + <-scheduler.TestSyncChan + + jobInfo = &importer.JobInfo{ + ID: importer.TestLastImportJobID.Load(), + TableSchema: "test_show_job", + TableName: "t3", + TableID: tableID3, + CreatedBy: "test_show_job2@localhost", + Parameters: importer.ImportParameters{ + FileLocation: fmt.Sprintf(`gs://test-show-job/t*.csv?access-key=xxxxxx&secret-access-key=xxxxxx&endpoint=%s`, gcsEndpoint), + Format: importer.DataFormatCSV, + }, + SourceFileSize: 6, + Status: "running", + Step: "importing", + Summary: &importer.JobSummary{ + ImportedRows: 2, + }, + ErrorMessage: "", + } + tk2 := testkit.NewTestKit(s.T(), s.store) + rows = tk2.MustQuery(fmt.Sprintf("show import job %d", importer.TestLastImportJobID.Load())).Rows() + s.Len(rows, 1) + s.compareJobInfoWithoutTime(jobInfo, rows[0]) + // show processlist, should be redacted too + procRows := tk2.MustQuery("show full processlist").Rows() + + var got bool + for _, r := range procRows { + user := r[1].(string) + sql := r[7].(string) + if user == "test_show_job2" && strings.Contains(sql, "IMPORT INTO") { + s.Contains(sql, "access-key=xxxxxx") + s.Contains(sql, "secret-access-key=xxxxxx") + s.NotContains(sql, "aaaaaa") + s.NotContains(sql, "bbbbbb") + got = true + } + } + s.True(got) + + // resume the scheduler + scheduler.TestSyncChan <- struct{}{} + // wait second subtask finish + <-scheduler.TestSyncChan + rows = tk2.MustQuery(fmt.Sprintf("show import job %d", importer.TestLastImportJobID.Load())).Rows() + s.Len(rows, 1) + jobInfo.Summary.ImportedRows = 4 + s.compareJobInfoWithoutTime(jobInfo, rows[0]) + // resume the scheduler, need disable failpoint first, otherwise the post-process subtask will be blocked + s.NoError(failpoint.Disable("github.com/pingcap/tidb/disttask/framework/scheduler/syncAfterSubtaskFinish")) + scheduler.TestSyncChan <- struct{}{} + }() + s.tk.MustQuery(fmt.Sprintf(`import into t3 FROM 'gs://test-show-job/t*.csv?access-key=aaaaaa&secret-access-key=bbbbbb&endpoint=%s' with thread=1`, gcsEndpoint)) + wg.Wait() + s.tk.MustQuery("select * from t3").Sort().Check(testkit.Rows("1", "2", "3", "4")) +} + +func (s *mockGCSSuite) TestShowDetachedJob() { + s.prepareAndUseDB("show_detached_job") + s.tk.MustExec("CREATE TABLE t1 (i INT PRIMARY KEY);") + s.tk.MustExec("CREATE TABLE t2 (i INT PRIMARY KEY);") + s.tk.MustExec("CREATE TABLE t3 (i INT PRIMARY KEY);") + s.server.CreateObject(fakestorage.Object{ + ObjectAttrs: fakestorage.ObjectAttrs{BucketName: "test-show-detached-job", Name: "t.csv"}, + Content: []byte("1\n2"), + }) + s.server.CreateObject(fakestorage.Object{ + ObjectAttrs: fakestorage.ObjectAttrs{BucketName: "test-show-detached-job", Name: "t2.csv"}, + Content: []byte("1\n1"), + }) + do, err := session.GetDomain(s.store) + s.NoError(err) + tableID1 := do.MustGetTableID(s.T(), "show_detached_job", "t1") + tableID2 := do.MustGetTableID(s.T(), "show_detached_job", "t2") + tableID3 := 
do.MustGetTableID(s.T(), "show_detached_job", "t3") + + jobInfo := &importer.JobInfo{ + TableSchema: "show_detached_job", + TableName: "t1", + TableID: tableID1, + CreatedBy: "root@%", + Parameters: importer.ImportParameters{ + FileLocation: fmt.Sprintf(`gs://test-show-detached-job/t.csv?endpoint=%s`, gcsEndpoint), + Format: importer.DataFormatCSV, + }, + SourceFileSize: 3, + Status: "pending", + Step: "", + } + s.NoError(s.tk.Session().Auth(&auth.UserIdentity{Username: "root", Hostname: "localhost"}, nil, nil, nil)) + result1 := s.tk.MustQuery(fmt.Sprintf(`import into t1 FROM 'gs://test-show-detached-job/t.csv?endpoint=%s' with detached`, + gcsEndpoint)).Rows() + s.Len(result1, 1) + jobID1, err := strconv.Atoi(result1[0][0].(string)) + s.NoError(err) + jobInfo.ID = int64(jobID1) + s.compareJobInfoWithoutTime(jobInfo, result1[0]) + + s.Eventually(func() bool { + rows := s.tk.MustQuery(fmt.Sprintf("show import job %d", jobID1)).Rows() + return rows[0][5] == "finished" + }, 10*time.Second, 500*time.Millisecond) + rows := s.tk.MustQuery(fmt.Sprintf("show import job %d", jobID1)).Rows() + s.Len(rows, 1) + jobInfo.Status = "finished" + jobInfo.Summary = &importer.JobSummary{ + ImportedRows: 2, + } + s.compareJobInfoWithoutTime(jobInfo, rows[0]) + s.tk.MustQuery("select * from t1").Check(testkit.Rows("1", "2")) + + // job fail with checksum mismatch + result2 := s.tk.MustQuery(fmt.Sprintf(`import into t2 FROM 'gs://test-show-detached-job/t2.csv?endpoint=%s' with detached`, + gcsEndpoint)).Rows() + s.Len(result2, 1) + jobID2, err := strconv.Atoi(result2[0][0].(string)) + s.NoError(err) + jobInfo = &importer.JobInfo{ + ID: int64(jobID2), + TableSchema: "show_detached_job", + TableName: "t2", + TableID: tableID2, + CreatedBy: "root@%", + Parameters: importer.ImportParameters{ + FileLocation: fmt.Sprintf(`gs://test-show-detached-job/t2.csv?endpoint=%s`, gcsEndpoint), + Format: importer.DataFormatCSV, + }, + SourceFileSize: 3, + Status: "pending", + Step: "", + } + s.compareJobInfoWithoutTime(jobInfo, result2[0]) + s.Eventually(func() bool { + rows = s.tk.MustQuery(fmt.Sprintf("show import job %d", jobID2)).Rows() + return rows[0][5] == "failed" + }, 10*time.Second, 500*time.Millisecond) + rows = s.tk.MustQuery(fmt.Sprintf("show import job %d", jobID2)).Rows() + s.Len(rows, 1) + jobInfo.Status = "failed" + jobInfo.Step = importer.JobStepValidating + jobInfo.ErrorMessage = `\[Lighting:Restore:ErrChecksumMismatch]checksum mismatched remote vs local.*` + s.compareJobInfoWithoutTime(jobInfo, rows[0]) + + // subtask fail with error + s.enableFailpoint("github.com/pingcap/tidb/disttask/importinto/errorWhenSortChunk", "return(true)") + result3 := s.tk.MustQuery(fmt.Sprintf(`import into t3 FROM 'gs://test-show-detached-job/t.csv?endpoint=%s' with detached`, + gcsEndpoint)).Rows() + s.Len(result3, 1) + jobID3, err := strconv.Atoi(result3[0][0].(string)) + s.NoError(err) + jobInfo = &importer.JobInfo{ + ID: int64(jobID3), + TableSchema: "show_detached_job", + TableName: "t3", + TableID: tableID3, + CreatedBy: "root@%", + Parameters: importer.ImportParameters{ + FileLocation: fmt.Sprintf(`gs://test-show-detached-job/t.csv?endpoint=%s`, gcsEndpoint), + Format: importer.DataFormatCSV, + }, + SourceFileSize: 3, + Status: "pending", + Step: "", + } + s.compareJobInfoWithoutTime(jobInfo, result3[0]) + s.Eventually(func() bool { + rows = s.tk.MustQuery(fmt.Sprintf("show import job %d", jobID3)).Rows() + return rows[0][5] == "failed" + }, 10*time.Second, 500*time.Millisecond) + rows = s.tk.MustQuery(fmt.Sprintf("show 
import job %d", jobID3)).Rows() + s.Len(rows, 1) + jobInfo.Status = "failed" + jobInfo.Step = importer.JobStepImporting + jobInfo.ErrorMessage = `occur an error when sort chunk.*` + s.compareJobInfoWithoutTime(jobInfo, rows[0]) +} + +func (s *mockGCSSuite) TestCancelJob() { + s.prepareAndUseDB("test_cancel_job") + s.tk.MustExec("CREATE TABLE t1 (i INT PRIMARY KEY);") + s.tk.MustExec("CREATE TABLE t2 (i INT PRIMARY KEY);") + s.server.CreateObject(fakestorage.Object{ + ObjectAttrs: fakestorage.ObjectAttrs{BucketName: "test_cancel_job", Name: "t.csv"}, + Content: []byte("1\n2"), + }) + s.T().Cleanup(func() { + _ = s.tk.Session().Auth(&auth.UserIdentity{Username: "root", Hostname: "localhost"}, nil, nil, nil) + }) + s.tk.MustExec(`DROP USER IF EXISTS 'test_cancel_job1'@'localhost';`) + s.tk.MustExec(`CREATE USER 'test_cancel_job1'@'localhost';`) + s.tk.MustExec(`GRANT SELECT,UPDATE,INSERT,DELETE,ALTER on test_cancel_job.* to 'test_cancel_job1'@'localhost'`) + s.tk.MustExec(`DROP USER IF EXISTS 'test_cancel_job2'@'localhost';`) + s.tk.MustExec(`CREATE USER 'test_cancel_job2'@'localhost';`) + s.tk.MustExec(`GRANT SELECT,UPDATE,INSERT,DELETE,ALTER on test_cancel_job.* to 'test_cancel_job2'@'localhost'`) + do, err := session.GetDomain(s.store) + s.NoError(err) + tableID1 := do.MustGetTableID(s.T(), "test_cancel_job", "t1") + tableID2 := do.MustGetTableID(s.T(), "test_cancel_job", "t2") + + // cancel non-exists job + err = s.tk.ExecToErr("cancel import job 9999999999") + s.ErrorIs(err, exeerrors.ErrLoadDataJobNotFound) + + getTask := func(jobID int64) *proto.Task { + globalTaskManager, err := storage.GetTaskManager() + s.NoError(err) + taskKey := importinto.TaskKey(jobID) + globalTask, err := globalTaskManager.GetGlobalTaskByKey(taskKey) + s.NoError(err) + return globalTask + } + + // cancel a running job created by self + s.enableFailpoint("github.com/pingcap/tidb/disttask/importinto/waitBeforeSortChunk", "return(true)") + s.enableFailpoint("github.com/pingcap/tidb/disttask/importinto/syncAfterJobStarted", "return(true)") + s.NoError(s.tk.Session().Auth(&auth.UserIdentity{Username: "test_cancel_job1", Hostname: "localhost"}, nil, nil, nil)) + result1 := s.tk.MustQuery(fmt.Sprintf(`import into t1 FROM 'gs://test_cancel_job/t.csv?endpoint=%s' with detached`, + gcsEndpoint)).Rows() + s.Len(result1, 1) + jobID1, err := strconv.Atoi(result1[0][0].(string)) + s.NoError(err) + // wait job started + <-importinto.TestSyncChan + // dist framework has bug, the cancelled status might be overridden by running status, + // so we wait it turn running before cancel, see https://github.com/pingcap/tidb/issues/44443 + time.Sleep(3 * time.Second) + s.tk.MustExec(fmt.Sprintf("cancel import job %d", jobID1)) + rows := s.tk.MustQuery(fmt.Sprintf("show import job %d", jobID1)).Rows() + s.Len(rows, 1) + jobInfo := &importer.JobInfo{ + ID: int64(jobID1), + TableSchema: "test_cancel_job", + TableName: "t1", + TableID: tableID1, + CreatedBy: "test_cancel_job1@localhost", + Parameters: importer.ImportParameters{ + FileLocation: fmt.Sprintf(`gs://test_cancel_job/t.csv?endpoint=%s`, gcsEndpoint), + Format: importer.DataFormatCSV, + }, + SourceFileSize: 3, + Status: "cancelled", + Step: importer.JobStepImporting, + ErrorMessage: "cancelled by user", + } + s.compareJobInfoWithoutTime(jobInfo, rows[0]) + s.Eventually(func() bool { + task := getTask(int64(jobID1)) + return task.State == proto.TaskStateReverted + }, 10*time.Second, 500*time.Millisecond) + + // cancel again, should fail + s.ErrorIs(s.tk.ExecToErr(fmt.Sprintf("cancel 
import job %d", jobID1)), exeerrors.ErrLoadDataInvalidOperation) + + // cancel a job created by test_cancel_job1 using test_cancel_job2, should fail + s.NoError(s.tk.Session().Auth(&auth.UserIdentity{Username: "test_cancel_job2", Hostname: "localhost"}, nil, nil, nil)) + s.ErrorIs(s.tk.ExecToErr(fmt.Sprintf("cancel import job %d", jobID1)), core.ErrSpecificAccessDenied) + // cancel by root, should pass privilege check + s.NoError(s.tk.Session().Auth(&auth.UserIdentity{Username: "root", Hostname: "localhost"}, nil, nil, nil)) + s.ErrorIs(s.tk.ExecToErr(fmt.Sprintf("cancel import job %d", jobID1)), exeerrors.ErrLoadDataInvalidOperation) + + // cancel job in post-process phase, using test_cancel_job2 + s.NoError(s.tk.Session().Auth(&auth.UserIdentity{Username: "test_cancel_job2", Hostname: "localhost"}, nil, nil, nil)) + s.NoError(failpoint.Disable("github.com/pingcap/tidb/disttask/importinto/waitBeforeSortChunk")) + s.NoError(failpoint.Disable("github.com/pingcap/tidb/disttask/importinto/syncAfterJobStarted")) + s.enableFailpoint("github.com/pingcap/tidb/disttask/importinto/syncBeforePostProcess", "return(true)") + s.enableFailpoint("github.com/pingcap/tidb/disttask/importinto/waitCtxDone", "return(true)") + result2 := s.tk.MustQuery(fmt.Sprintf(`import into t2 FROM 'gs://test_cancel_job/t.csv?endpoint=%s' with detached`, + gcsEndpoint)).Rows() + s.Len(result2, 1) + jobID2, err := strconv.Atoi(result2[0][0].(string)) + s.NoError(err) + // wait job reach post-process phase + <-importinto.TestSyncChan + s.tk.MustExec(fmt.Sprintf("cancel import job %d", jobID2)) + // resume the job + importinto.TestSyncChan <- struct{}{} + rows2 := s.tk.MustQuery(fmt.Sprintf("show import job %d", jobID2)).Rows() + s.Len(rows2, 1) + jobInfo = &importer.JobInfo{ + ID: int64(jobID2), + TableSchema: "test_cancel_job", + TableName: "t2", + TableID: tableID2, + CreatedBy: "test_cancel_job2@localhost", + Parameters: importer.ImportParameters{ + FileLocation: fmt.Sprintf(`gs://test_cancel_job/t.csv?endpoint=%s`, gcsEndpoint), + Format: importer.DataFormatCSV, + }, + SourceFileSize: 3, + Status: "cancelled", + Step: importer.JobStepValidating, + ErrorMessage: "cancelled by user", + } + s.compareJobInfoWithoutTime(jobInfo, rows2[0]) + globalTaskManager, err := storage.GetTaskManager() + s.NoError(err) + taskKey := importinto.TaskKey(int64(jobID2)) + s.NoError(err) + s.Eventually(func() bool { + globalTask, err2 := globalTaskManager.GetGlobalTaskByKey(taskKey) + s.NoError(err2) + subtasks, err2 := globalTaskManager.GetSubtasksByStep(globalTask.ID, importinto.StepPostProcess) + s.NoError(err2) + s.Len(subtasks, 2) // framework will generate a subtask when canceling + var cancelled bool + for _, st := range subtasks { + if st.State == proto.TaskStateCanceled { + cancelled = true + break + } + } + return globalTask.State == proto.TaskStateReverted && cancelled + }, 5*time.Second, 1*time.Second) + + // todo: enable it when https://github.com/pingcap/tidb/issues/44443 fixed + //// cancel a pending job created by test_cancel_job2 using root + //s.NoError(failpoint.Disable("github.com/pingcap/tidb/disttask/importinto/syncAfterJobStarted")) + //s.enableFailpoint("github.com/pingcap/tidb/disttask/importinto/syncBeforeJobStarted", "return(true)") + //result2 := s.tk.MustQuery(fmt.Sprintf(`import into t2 FROM 'gs://test_cancel_job/t.csv?endpoint=%s' with detached`, + // gcsEndpoint)).Rows() + //s.Len(result2, 1) + //jobID2, err := strconv.Atoi(result2[0][0].(string)) + //s.NoError(err) + //// wait job reached to the point before job 
started + //<-loaddata.TestSyncChan + //s.NoError(s.tk.Session().Auth(&auth.UserIdentity{Username: "root", Hostname: "localhost"}, nil, nil, nil)) + //s.tk.MustExec(fmt.Sprintf("cancel import job %d", jobID2)) + //// resume the job + //loaddata.TestSyncChan <- struct{}{} + //rows = s.tk.MustQuery(fmt.Sprintf("show import job %d", jobID2)).Rows() + //s.Len(rows, 1) + //jobInfo = &importer.JobInfo{ + // ID: int64(jobID2), + // TableSchema: "test_cancel_job", + // TableName: "t2", + // TableID: tableID2, + // CreatedBy: "test_cancel_job2@localhost", + // Parameters: importer.ImportParameters{ + // FileLocation: fmt.Sprintf(`gs://test_cancel_job/t.csv?endpoint=%s`, gcsEndpoint), + // Format: importer.DataFormatCSV, + // }, + // SourceFileSize: 3, + // Status: "cancelled", + // Step: "", + // ErrorMessage: "cancelled by user", + //} + //s.compareJobInfoWithoutTime(jobInfo, rows[0]) + //s.Eventually(func() bool { + // task := getTask(int64(jobID2)) + // return task.State == proto.TaskStateReverted + //}, 10*time.Second, 500*time.Millisecond) +} + +func (s *mockGCSSuite) TestJobFailWhenDispatchSubtask() { + s.prepareAndUseDB("fail_job_after_import") + s.tk.MustExec("CREATE TABLE t1 (i INT PRIMARY KEY);") + s.server.CreateObject(fakestorage.Object{ + ObjectAttrs: fakestorage.ObjectAttrs{BucketName: "fail_job_after_import", Name: "t.csv"}, + Content: []byte("1\n2"), + }) + do, err := session.GetDomain(s.store) + s.NoError(err) + tableID1 := do.MustGetTableID(s.T(), "fail_job_after_import", "t1") + + jobInfo := &importer.JobInfo{ + TableSchema: "fail_job_after_import", + TableName: "t1", + TableID: tableID1, + CreatedBy: "root@%", + Parameters: importer.ImportParameters{ + FileLocation: fmt.Sprintf(`gs://fail_job_after_import/t.csv?endpoint=%s`, gcsEndpoint), + Format: importer.DataFormatCSV, + }, + SourceFileSize: 3, + Status: "failed", + Step: importer.JobStepValidating, + ErrorMessage: "injected error after StepImport", + } + s.enableFailpoint("github.com/pingcap/tidb/disttask/importinto/failWhenDispatchPostProcessSubtask", "return(true)") + s.enableFailpoint("github.com/pingcap/tidb/executor/importer/setLastImportJobID", `return(true)`) + s.NoError(s.tk.Session().Auth(&auth.UserIdentity{Username: "root", Hostname: "localhost"}, nil, nil, nil)) + err = s.tk.QueryToErr(fmt.Sprintf(`import into t1 FROM 'gs://fail_job_after_import/t.csv?endpoint=%s'`, gcsEndpoint)) + s.ErrorContains(err, "injected error after StepImport") + result1 := s.tk.MustQuery(fmt.Sprintf("show import job %d", importer.TestLastImportJobID.Load())).Rows() + s.Len(result1, 1) + jobID1, err := strconv.Atoi(result1[0][0].(string)) + s.NoError(err) + jobInfo.ID = int64(jobID1) + s.compareJobInfoWithoutTime(jobInfo, result1[0]) +} + +func (s *mockGCSSuite) TestKillBeforeFinish() { + s.cleanupSysTables() + s.tk.MustExec("DROP DATABASE IF EXISTS kill_job;") + s.tk.MustExec("CREATE DATABASE kill_job;") + s.tk.MustExec(`CREATE TABLE kill_job.t (a INT, b INT, c int);`) + s.server.CreateObject(fakestorage.Object{ + ObjectAttrs: fakestorage.ObjectAttrs{BucketName: "test-load", Name: "t-1.tsv"}, + Content: []byte("1,11,111"), + }) + + s.enableFailpoint("github.com/pingcap/tidb/disttask/importinto/syncBeforeSortChunk", "return(true)") + s.enableFailpoint("github.com/pingcap/tidb/executor/cancellableCtx", "return(true)") + s.enableFailpoint("github.com/pingcap/tidb/executor/importer/setLastImportJobID", `return(true)`) + wg := sync.WaitGroup{} + wg.Add(1) + go func() { + defer wg.Done() + sql := fmt.Sprintf(`IMPORT INTO kill_job.t FROM 
'gs://test-load/t-*.tsv?endpoint=%s'`, gcsEndpoint) + err := s.tk.QueryToErr(sql) + s.ErrorIs(errors.Cause(err), context.Canceled) + }() + // wait for the task reach sort chunk + <-importinto.TestSyncChan + // cancel the job + executor.TestCancelFunc() + // continue the execution + importinto.TestSyncChan <- struct{}{} + wg.Wait() + jobID := importer.TestLastImportJobID.Load() + rows := s.tk.MustQuery(fmt.Sprintf("show import job %d", jobID)).Rows() + s.Len(rows, 1) + s.Equal("cancelled", rows[0][5]) + globalTaskManager, err := storage.GetTaskManager() + s.NoError(err) + taskKey := importinto.TaskKey(jobID) + s.NoError(err) + s.Eventually(func() bool { + globalTask, err2 := globalTaskManager.GetGlobalTaskByKey(taskKey) + s.NoError(err2) + return globalTask.State == proto.TaskStateReverted + }, 5*time.Second, 1*time.Second) +}