From 9fad344aac59214359d5db333faaed1aa142420a Mon Sep 17 00:00:00 2001 From: ekexium Date: Mon, 18 Dec 2023 15:43:53 +0800 Subject: [PATCH] txn: make load data transactional (#49079) ref pingcap/tidb#49078 --- pkg/executor/load_data.go | 86 +++++--- pkg/executor/test/loaddatatest/BUILD.bazel | 2 +- .../test/loaddatatest/load_data_test.go | 164 +++++++-------- .../test/loadremotetest/one_csv_test.go | 41 ++++ pkg/executor/test/writetest/write_test.go | 53 ++--- pkg/privilege/privileges/privileges_test.go | 3 +- pkg/server/conn.go | 199 +++++++++--------- .../testserverclient/server_client.go | 199 ++++++++++++++++++ pkg/server/tests/tidb_serial_test.go | 6 + 9 files changed, 515 insertions(+), 238 deletions(-) diff --git a/pkg/executor/load_data.go b/pkg/executor/load_data.go index f96a70ceb99cf..15fbcbc8a6048 100644 --- a/pkg/executor/load_data.go +++ b/pkg/executor/load_data.go @@ -48,16 +48,64 @@ import ( "golang.org/x/sync/errgroup" ) +// LoadDataVarKey is a variable key for load data. +const LoadDataVarKey loadDataVarKeyType = 0 + +// LoadDataReaderBuilderKey stores the reader channel that reads from the connection. +const LoadDataReaderBuilderKey loadDataVarKeyType = 1 + var ( taskQueueSize = 16 // the maximum number of pending tasks to commit in queue ) +// LoadDataReaderBuilder is a function type that builds a reader from a file path. +type LoadDataReaderBuilder func(filepath string) ( + r io.ReadCloser, err error, +) + // LoadDataExec represents a load data executor. type LoadDataExec struct { exec.BaseExecutor FileLocRef ast.FileLocRefTp loadDataWorker *LoadDataWorker + + // fields for loading local file + infileReader io.ReadCloser +} + +// Open implements the Executor interface. +func (e *LoadDataExec) Open(_ context.Context) error { + if rb, ok := e.Ctx().Value(LoadDataReaderBuilderKey).(LoadDataReaderBuilder); ok { + var err error + e.infileReader, err = rb(e.loadDataWorker.GetInfilePath()) + if err != nil { + return err + } + } + return nil +} + +// Close implements the Executor interface. +func (e *LoadDataExec) Close() error { + return e.closeLocalReader(nil) +} + +func (e *LoadDataExec) closeLocalReader(originalErr error) error { + err := originalErr + if e.infileReader != nil { + if err2 := e.infileReader.Close(); err2 != nil { + logutil.BgLogger().Error( + "close local reader failed", zap.Error(err2), + zap.NamedError("original error", originalErr), + ) + if err == nil { + err = err2 + } + } + e.infileReader = nil + } + return err } // Next implements the Executor Next interface. @@ -66,14 +114,17 @@ func (e *LoadDataExec) Next(ctx context.Context, _ *chunk.Chunk) (err error) { case ast.FileLocServerOrRemote: return e.loadDataWorker.loadRemote(ctx) case ast.FileLocClient: - // let caller use handleFileTransInConn to read data in this connection + // This is for legacy test only + // TODO: adjust tests to remove LoadDataVarKey sctx := e.loadDataWorker.UserSctx - val := sctx.Value(LoadDataVarKey) - if val != nil { - sctx.SetValue(LoadDataVarKey, nil) - return errors.New("previous load data option wasn't closed normally") - } sctx.SetValue(LoadDataVarKey, e.loadDataWorker) + + err = e.loadDataWorker.LoadLocal(ctx, e.infileReader) + if err != nil { + logutil.Logger(ctx).Error("load local data failed", zap.Error(err)) + err = e.closeLocalReader(err) + return err + } } return nil } @@ -145,6 +196,10 @@ func (e *LoadDataWorker) loadRemote(ctx context.Context) error { // LoadLocal reads from client connection and do load data job. func (e *LoadDataWorker) LoadLocal(ctx context.Context, r io.ReadCloser) error { + if r == nil { + return errors.New("load local data, reader is nil") + } + compressTp := mydump.ParseCompressionOnFileExtension(e.GetInfilePath()) compressTp2, err := mydump.ToStorageCompressType(compressTp) if err != nil { @@ -172,11 +227,6 @@ func (e *LoadDataWorker) load(ctx context.Context, readerInfos []importer.LoadDa commitTaskCh := make(chan commitTask, taskQueueSize) // commitWork goroutines -> done -> UpdateJobProgress goroutine - // TODO: support explicit transaction and non-autocommit - if err = sessiontxn.NewTxn(groupCtx, e.UserSctx); err != nil { - return err - } - // processOneStream goroutines. group.Go(func() error { err2 := encoder.processStream(groupCtx, readerInfoCh, commitTaskCh) @@ -530,16 +580,6 @@ func (w *commitWorker) commitWork(ctx context.Context, inCh <-chan commitTask) ( zap.Stack("stack")) err = util.GetRecoverError(r) } - - if err != nil { - background := context.Background() - w.Ctx().StmtRollback(background, false) - w.Ctx().RollbackTxn(background) - } else { - if err = w.Ctx().CommitTxn(ctx); err != nil { - logutil.Logger(ctx).Error("commit error refresh", zap.Error(err)) - } - } }() var ( @@ -578,7 +618,6 @@ func (w *commitWorker) commitOneTask(ctx context.Context, task commitTask) error failpoint.Inject("commitOneTaskErr", func() { failpoint.Return(errors.New("mock commit one task error")) }) - w.Ctx().StmtCommit(ctx) return nil } @@ -734,6 +773,3 @@ type loadDataVarKeyType int func (loadDataVarKeyType) String() string { return "load_data_var" } - -// LoadDataVarKey is a variable key for load data. -const LoadDataVarKey loadDataVarKeyType = 0 diff --git a/pkg/executor/test/loaddatatest/BUILD.bazel b/pkg/executor/test/loaddatatest/BUILD.bazel index 4f0f27f54363e..7c42b26c4e066 100644 --- a/pkg/executor/test/loaddatatest/BUILD.bazel +++ b/pkg/executor/test/loaddatatest/BUILD.bazel @@ -9,7 +9,7 @@ go_test( ], flaky = True, race = "on", - shard_count = 10, + shard_count = 11, deps = [ "//br/pkg/lightning/mydump", "//pkg/config", diff --git a/pkg/executor/test/loaddatatest/load_data_test.go b/pkg/executor/test/loaddatatest/load_data_test.go index 1284475340815..bb9f03d7a32d5 100644 --- a/pkg/executor/test/loaddatatest/load_data_test.go +++ b/pkg/executor/test/loaddatatest/load_data_test.go @@ -15,7 +15,8 @@ package loaddatatest import ( - "context" + "fmt" + "io" "testing" "github.com/pingcap/tidb/br/pkg/lightning/mydump" @@ -34,25 +35,26 @@ type testCase struct { func checkCases( tests []testCase, - ld *executor.LoadDataWorker, + loadSQL string, t *testing.T, tk *testkit.TestKit, ctx sessionctx.Context, selectSQL, deleteSQL string, ) { for _, tt := range tests { - parser, err := mydump.NewCSVParser( - context.Background(), - ld.GetController().GenerateCSVConfig(), - mydump.NewStringReader(string(tt.data)), - 1, - nil, - false, - nil) - require.NoError(t, err) - - err = ld.TestLoadLocal(parser) - require.NoError(t, err) + var reader io.ReadCloser = mydump.NewStringReader(string(tt.data)) + var readerBuilder executor.LoadDataReaderBuilder = func(_ string) ( + r io.ReadCloser, err error, + ) { + return reader, nil + } + + ctx.SetValue(executor.LoadDataReaderBuilderKey, readerBuilder) + tk.MustExec(loadSQL) + warnings := tk.Session().GetSessionVars().StmtCtx.GetWarnings() + for _, w := range warnings { + fmt.Printf("warnnig: %#v\n", w.Err.Error()) + } require.Equal(t, tt.expectedMsg, tk.Session().LastMessage(), tt.expected) tk.MustQuery(selectSQL).Check(testkit.RowsWithSep("|", tt.expected...)) tk.MustExec(deleteSQL) @@ -80,7 +82,7 @@ func TestLoadDataInitParam(t *testing.T) { // null def values testFunc := func(sql string, expectedNullDef []string, expectedNullOptEnclosed bool) { - require.NoError(t, tk.ExecToErr(sql)) + require.ErrorContains(t, tk.ExecToErr(sql), "reader is nil") defer ctx.SetValue(executor.LoadDataVarKey, nil) ld, ok := ctx.Value(executor.LoadDataVarKey).(*executor.LoadDataWorker) require.True(t, ok) @@ -102,11 +104,26 @@ func TestLoadDataInitParam(t *testing.T) { []string{"NULL"}, false) // positive case - require.NoError(t, tk.ExecToErr("load data local infile '/a' format 'sql file' into table load_data_test")) + require.ErrorContains( + t, tk.ExecToErr( + "load data local infile '/a' format 'sql file' into table"+ + " load_data_test", + ), "reader is nil", + ) ctx.SetValue(executor.LoadDataVarKey, nil) - require.NoError(t, tk.ExecToErr("load data local infile '/a' into table load_data_test fields terminated by 'a'")) + require.ErrorContains( + t, tk.ExecToErr( + "load data local infile '/a' into table load_data_test fields"+ + " terminated by 'a'", + ), "reader is nil", + ) ctx.SetValue(executor.LoadDataVarKey, nil) - require.NoError(t, tk.ExecToErr("load data local infile '/a' format 'delimited data' into table load_data_test fields terminated by 'a'")) + require.ErrorContains( + t, tk.ExecToErr( + "load data local infile '/a' format 'delimited data' into"+ + " table load_data_test fields terminated by 'a'", + ), "reader is nil", + ) ctx.SetValue(executor.LoadDataVarKey, nil) // According to https://dev.mysql.com/doc/refman/8.0/en/load-data.html , fixed-row format should be used when fields @@ -130,12 +147,8 @@ func TestLoadData(t *testing.T) { tk.MustExec(createSQL) err = tk.ExecToErr("load data infile '/tmp/nonexistence.csv' into table load_data_test") require.Error(t, err) - tk.MustExec("load data local infile '/tmp/nonexistence.csv' ignore into table load_data_test") + loadSQL := "load data local infile '/tmp/nonexistence.csv' ignore into table load_data_test" ctx := tk.Session().(sessionctx.Context) - ld, ok := ctx.Value(executor.LoadDataVarKey).(*executor.LoadDataWorker) - require.True(t, ok) - defer ctx.SetValue(executor.LoadDataVarKey, nil) - require.NotNil(t, ld) deleteSQL := "delete from load_data_test" selectSQL := "select * from load_data_test;" @@ -164,10 +177,11 @@ func TestLoadData(t *testing.T) { {[]byte("\t2\t3\t4\t5\n"), []string{"10|2|3|4"}, "Records: 1 Deleted: 0 Skipped: 0 Warnings: 2"}, {[]byte("\t2\t34\t5\n"), []string{"11|2|34|5"}, "Records: 1 Deleted: 0 Skipped: 0 Warnings: 1"}, } - checkCases(tests, ld, t, tk, ctx, selectSQL, deleteSQL) + checkCases(tests, loadSQL, t, tk, ctx, selectSQL, deleteSQL) // lines starting symbol is "" and terminated symbol length is 2, ReadOneBatchRows returns data is nil - ld.GetController().LinesTerminatedBy = "||" + loadSQL = "load data local infile '/tmp/nonexistence." + + "csv' ignore into table load_data_test lines terminated by '||'" tests = []testCase{ {[]byte("0\t2\t3\t4\t5||"), []string{"12|2|3|4"}, "Records: 1 Deleted: 0 Skipped: 0 Warnings: 1"}, {[]byte("1\t2\t3\t4\t5||"), []string{"1|2|3|4"}, "Records: 1 Deleted: 0 Skipped: 0 Warnings: 1"}, @@ -179,12 +193,11 @@ func TestLoadData(t *testing.T) { []string{"4|2|3|4", "5|22|33|", "6|222||"}, "Records: 3 Deleted: 0 Skipped: 0 Warnings: 3"}, {[]byte("6\t2\t34\t5||"), []string{"6|2|34|5"}, trivialMsg}, } - checkCases(tests, ld, t, tk, ctx, selectSQL, deleteSQL) + checkCases(tests, loadSQL, t, tk, ctx, selectSQL, deleteSQL) // fields and lines aren't default, ReadOneBatchRows returns data is nil - ld.GetController().FieldsTerminatedBy = "\\" - ld.GetController().LinesStartingBy = "xxx" - ld.GetController().LinesTerminatedBy = "|!#^" + loadSQL = "load data local infile '/tmp/nonexistence.csv' " + + `ignore into table load_data_test fields terminated by '\\' lines starting by 'xxx' terminated by '|!#^'` tests = []testCase{ {[]byte("xxx|!#^"), []string{"13|||"}, "Records: 1 Deleted: 0 Skipped: 0 Warnings: 2"}, {[]byte("xxx\\|!#^"), []string{"14|0||"}, "Records: 1 Deleted: 0 Skipped: 0 Warnings: 3"}, @@ -219,7 +232,7 @@ func TestLoadData(t *testing.T) { []string{"25|2|3|4", "27|222||"}, "Records: 2 Deleted: 0 Skipped: 0 Warnings: 2"}, {[]byte("xxx\\2\\34\\5|!#^"), []string{"28|2|34|5"}, "Records: 1 Deleted: 0 Skipped: 0 Warnings: 1"}, } - checkCases(tests, ld, t, tk, ctx, selectSQL, deleteSQL) + checkCases(tests, loadSQL, t, tk, ctx, selectSQL, deleteSQL) // TODO: not support it now // lines starting symbol is the same as terminated symbol, ReadOneBatchRows returns data is nil @@ -258,21 +271,25 @@ func TestLoadData(t *testing.T) { //checkCases(tests, ld, t, tk, ctx, selectSQL, deleteSQL) // test line terminator in field quoter - ld.GetController().LinesTerminatedBy = "\n" - ld.GetController().FieldsEnclosedBy = `"` + loadSQL = "load data local infile '/tmp/nonexistence.csv' " + + "ignore into table load_data_test " + + "fields terminated by '\\\\' enclosed by '\\\"' " + + "lines starting by 'xxx' terminated by '\\n'" tests = []testCase{ {[]byte("xxx1\\1\\\"2\n\"\\3\nxxx4\\4\\\"5\n5\"\\6"), []string{"1|1|2\n|3", "4|4|5\n5|6"}, "Records: 2 Deleted: 0 Skipped: 0 Warnings: 0"}, } - checkCases(tests, ld, t, tk, ctx, selectSQL, deleteSQL) + checkCases(tests, loadSQL, t, tk, ctx, selectSQL, deleteSQL) - ld.GetController().LinesTerminatedBy = "#\n" - ld.GetController().FieldsTerminatedBy = "#" + loadSQL = "load data local infile '/tmp/nonexistence.csv' " + + "ignore into table load_data_test " + + "fields terminated by '#' enclosed by '\\\"' " + + "lines starting by 'xxx' terminated by '#\\n'" tests = []testCase{ {[]byte("xxx1#\nxxx2#\n"), []string{"1|||", "2|||"}, "Records: 2 Deleted: 0 Skipped: 0 Warnings: 2"}, {[]byte("xxx1#2#3#4#\nnxxx2#3#4#5#\n"), []string{"1|2|3|4", "2|3|4|5"}, "Records: 2 Deleted: 0 Skipped: 0 Warnings: 0"}, {[]byte("xxx1#2#\"3#\"#\"4\n\"#\nxxx2#3#\"#4#\n\"#5#\n"), []string{"1|2|3#|4", "2|3|#4#\n|5"}, "Records: 2 Deleted: 0 Skipped: 0 Warnings: 0"}, } - checkCases(tests, ld, t, tk, ctx, selectSQL, deleteSQL) + checkCases(tests, loadSQL, t, tk, ctx, selectSQL, deleteSQL) // TODO: now support it now //ld.LinesInfo.Terminated = "#" @@ -293,12 +310,8 @@ func TestLoadDataEscape(t *testing.T) { tk := testkit.NewTestKit(t, store) tk.MustExec("use test; drop table if exists load_data_test;") tk.MustExec("CREATE TABLE load_data_test (id INT NOT NULL PRIMARY KEY, value TEXT NOT NULL) CHARACTER SET utf8") - tk.MustExec("load data local infile '/tmp/nonexistence.csv' into table load_data_test") + loadSQL := "load data local infile '/tmp/nonexistence.csv' into table load_data_test" ctx := tk.Session().(sessionctx.Context) - ld, ok := ctx.Value(executor.LoadDataVarKey).(*executor.LoadDataWorker) - require.True(t, ok) - defer ctx.SetValue(executor.LoadDataVarKey, nil) - require.NotNil(t, ld) // test escape tests := []testCase{ // data1 = nil, data2 != nil @@ -314,7 +327,7 @@ func TestLoadDataEscape(t *testing.T) { } deleteSQL := "delete from load_data_test" selectSQL := "select * from load_data_test;" - checkCases(tests, ld, t, tk, ctx, selectSQL, deleteSQL) + checkCases(tests, loadSQL, t, tk, ctx, selectSQL, deleteSQL) } // TestLoadDataSpecifiedColumns reuse TestLoadDataEscape's test case :-) @@ -324,12 +337,8 @@ func TestLoadDataSpecifiedColumns(t *testing.T) { tk := testkit.NewTestKit(t, store) tk.MustExec("use test; drop table if exists load_data_test;") tk.MustExec(`create table load_data_test (id int PRIMARY KEY AUTO_INCREMENT, c1 int, c2 varchar(255) default "def", c3 int default 0);`) - tk.MustExec("load data local infile '/tmp/nonexistence.csv' into table load_data_test (c1, c2)") + loadSQL := "load data local infile '/tmp/nonexistence.csv' into table load_data_test (c1, c2)" ctx := tk.Session().(sessionctx.Context) - ld, ok := ctx.Value(executor.LoadDataVarKey).(*executor.LoadDataWorker) - require.True(t, ok) - defer ctx.SetValue(executor.LoadDataVarKey, nil) - require.NotNil(t, ld) // test tests := []testCase{ {[]byte("7\ta string\n"), []string{"1|7|a string|0"}, trivialMsg}, @@ -342,7 +351,7 @@ func TestLoadDataSpecifiedColumns(t *testing.T) { } deleteSQL := "delete from load_data_test" selectSQL := "select * from load_data_test;" - checkCases(tests, ld, t, tk, ctx, selectSQL, deleteSQL) + checkCases(tests, loadSQL, t, tk, ctx, selectSQL, deleteSQL) } func TestLoadDataIgnoreLines(t *testing.T) { @@ -350,19 +359,15 @@ func TestLoadDataIgnoreLines(t *testing.T) { tk := testkit.NewTestKit(t, store) tk.MustExec("use test; drop table if exists load_data_test;") tk.MustExec("CREATE TABLE load_data_test (id INT NOT NULL PRIMARY KEY, value TEXT NOT NULL) CHARACTER SET utf8") - tk.MustExec("load data local infile '/tmp/nonexistence.csv' into table load_data_test ignore 1 lines") + loadSQL := "load data local infile '/tmp/nonexistence.csv' into table load_data_test ignore 1 lines" ctx := tk.Session().(sessionctx.Context) - ld, ok := ctx.Value(executor.LoadDataVarKey).(*executor.LoadDataWorker) - require.True(t, ok) - defer ctx.SetValue(executor.LoadDataVarKey, nil) - require.NotNil(t, ld) tests := []testCase{ {[]byte("1\tline1\n2\tline2\n"), []string{"2|line2"}, "Records: 1 Deleted: 0 Skipped: 0 Warnings: 0"}, {[]byte("1\tline1\n2\tline2\n3\tline3\n"), []string{"2|line2", "3|line3"}, "Records: 2 Deleted: 0 Skipped: 0 Warnings: 0"}, } deleteSQL := "delete from load_data_test" selectSQL := "select * from load_data_test;" - checkCases(tests, ld, t, tk, ctx, selectSQL, deleteSQL) + checkCases(tests, loadSQL, t, tk, ctx, selectSQL, deleteSQL) } func TestLoadDataNULL(t *testing.T) { @@ -374,13 +379,9 @@ func TestLoadDataNULL(t *testing.T) { tk := testkit.NewTestKit(t, store) tk.MustExec("use test; drop table if exists load_data_test;") tk.MustExec("CREATE TABLE load_data_test (id VARCHAR(20), value VARCHAR(20)) CHARACTER SET utf8") - tk.MustExec(`load data local infile '/tmp/nonexistence.csv' into table load_data_test -FIELDS TERMINATED BY ',' ENCLOSED BY '"' LINES TERMINATED BY '\n';`) + loadSQL := `load data local infile '/tmp/nonexistence.csv' into table load_data_test +FIELDS TERMINATED BY ',' ENCLOSED BY '"' LINES TERMINATED BY '\n';` ctx := tk.Session().(sessionctx.Context) - ld, ok := ctx.Value(executor.LoadDataVarKey).(*executor.LoadDataWorker) - require.True(t, ok) - defer ctx.SetValue(executor.LoadDataVarKey, nil) - require.NotNil(t, ld) tests := []testCase{ { []byte(`NULL,"NULL" @@ -392,7 +393,7 @@ FIELDS TERMINATED BY ',' ENCLOSED BY '"' LINES TERMINATED BY '\n';`) } deleteSQL := "delete from load_data_test" selectSQL := "select * from load_data_test;" - checkCases(tests, ld, t, tk, ctx, selectSQL, deleteSQL) + checkCases(tests, loadSQL, t, tk, ctx, selectSQL, deleteSQL) } func TestLoadDataReplace(t *testing.T) { @@ -401,19 +402,15 @@ func TestLoadDataReplace(t *testing.T) { tk.MustExec("USE test; DROP TABLE IF EXISTS load_data_replace;") tk.MustExec("CREATE TABLE load_data_replace (id INT NOT NULL PRIMARY KEY, value TEXT NOT NULL)") tk.MustExec("INSERT INTO load_data_replace VALUES(1,'val 1'),(2,'val 2')") - tk.MustExec("LOAD DATA LOCAL INFILE '/tmp/nonexistence.csv' REPLACE INTO TABLE load_data_replace") + loadSQL := "LOAD DATA LOCAL INFILE '/tmp/nonexistence.csv' REPLACE INTO TABLE load_data_replace" ctx := tk.Session().(sessionctx.Context) - ld, ok := ctx.Value(executor.LoadDataVarKey).(*executor.LoadDataWorker) - require.True(t, ok) - defer ctx.SetValue(executor.LoadDataVarKey, nil) - require.NotNil(t, ld) tests := []testCase{ {[]byte("1\tline1\n2\tline2\n"), []string{"1|line1", "2|line2"}, "Records: 2 Deleted: 2 Skipped: 0 Warnings: 0"}, {[]byte("2\tnew line2\n3\tnew line3\n"), []string{"1|line1", "2|new line2", "3|new line3"}, "Records: 2 Deleted: 1 Skipped: 0 Warnings: 0"}, } deleteSQL := "DO 1" selectSQL := "TABLE load_data_replace;" - checkCases(tests, ld, t, tk, ctx, selectSQL, deleteSQL) + checkCases(tests, loadSQL, t, tk, ctx, selectSQL, deleteSQL) } // TestLoadDataOverflowBigintUnsigned related to issue 6360 @@ -422,19 +419,15 @@ func TestLoadDataOverflowBigintUnsigned(t *testing.T) { tk := testkit.NewTestKit(t, store) tk.MustExec("use test; drop table if exists load_data_test;") tk.MustExec("CREATE TABLE load_data_test (a bigint unsigned);") - tk.MustExec("load data local infile '/tmp/nonexistence.csv' into table load_data_test") + loadSQL := "load data local infile '/tmp/nonexistence.csv' into table load_data_test" ctx := tk.Session().(sessionctx.Context) - ld, ok := ctx.Value(executor.LoadDataVarKey).(*executor.LoadDataWorker) - require.True(t, ok) - defer ctx.SetValue(executor.LoadDataVarKey, nil) - require.NotNil(t, ld) tests := []testCase{ {[]byte("-1\n-18446744073709551615\n-18446744073709551616\n"), []string{"0", "0", "0"}, "Records: 3 Deleted: 0 Skipped: 0 Warnings: 3"}, {[]byte("-9223372036854775809\n18446744073709551616\n"), []string{"0", "18446744073709551615"}, "Records: 2 Deleted: 0 Skipped: 0 Warnings: 2"}, } deleteSQL := "delete from load_data_test" selectSQL := "select * from load_data_test;" - checkCases(tests, ld, t, tk, ctx, selectSQL, deleteSQL) + checkCases(tests, loadSQL, t, tk, ctx, selectSQL, deleteSQL) } func TestLoadDataWithUppercaseUserVars(t *testing.T) { @@ -442,19 +435,15 @@ func TestLoadDataWithUppercaseUserVars(t *testing.T) { tk := testkit.NewTestKit(t, store) tk.MustExec("use test; drop table if exists load_data_test;") tk.MustExec("CREATE TABLE load_data_test (a int, b int);") - tk.MustExec("load data local infile '/tmp/nonexistence.csv' into table load_data_test (@V1)" + - " set a = @V1, b = @V1*100") + loadSQL := "load data local infile '/tmp/nonexistence.csv' into table load_data_test (@V1)" + + " set a = @V1, b = @V1*100" ctx := tk.Session().(sessionctx.Context) - ld, ok := ctx.Value(executor.LoadDataVarKey).(*executor.LoadDataWorker) - require.True(t, ok) - defer ctx.SetValue(executor.LoadDataVarKey, nil) - require.NotNil(t, ld) tests := []testCase{ {[]byte("1\n2\n"), []string{"1|100", "2|200"}, "Records: 2 Deleted: 0 Skipped: 0 Warnings: 0"}, } deleteSQL := "delete from load_data_test" selectSQL := "select * from load_data_test;" - checkCases(tests, ld, t, tk, ctx, selectSQL, deleteSQL) + checkCases(tests, loadSQL, t, tk, ctx, selectSQL, deleteSQL) } func TestLoadDataIntoPartitionedTable(t *testing.T) { @@ -465,14 +454,21 @@ func TestLoadDataIntoPartitionedTable(t *testing.T) { "partition p0 values less than (4)," + "partition p1 values less than (7)," + "partition p2 values less than (11))") - tk.MustExec("load data local infile '/tmp/nonexistence.csv' into table range_t fields terminated by ','") ctx := tk.Session().(sessionctx.Context) - ld := ctx.Value(executor.LoadDataVarKey).(*executor.LoadDataWorker) - + loadSQL := "load data local infile '/tmp/nonexistence.csv' into table range_t fields terminated by ','" tests := []testCase{ {[]byte("1,2\n3,4\n5,6\n7,8\n9,10\n"), []string{"1|2", "3|4", "5|6", "7|8", "9|10"}, "Records: 5 Deleted: 0 Skipped: 0 Warnings: 0"}, } deleteSQL := "delete from range_t" selectSQL := "select * from range_t order by a;" - checkCases(tests, ld, t, tk, ctx, selectSQL, deleteSQL) + checkCases(tests, loadSQL, t, tk, ctx, selectSQL, deleteSQL) +} + +func TestLoadDataFromServerFile(t *testing.T) { + store := testkit.CreateMockStore(t) + tk := testkit.NewTestKit(t, store) + tk.MustExec("use test") + tk.MustExec("create table load_data_test (a int)") + err := tk.ExecToErr("load data infile 'remote.csv' into table load_data_test") + require.ErrorContains(t, err, "[executor:8154]Don't support load data from tidb-server's disk.") } diff --git a/pkg/executor/test/loadremotetest/one_csv_test.go b/pkg/executor/test/loadremotetest/one_csv_test.go index 4a021ecebbe56..2fefa018d63fa 100644 --- a/pkg/executor/test/loadremotetest/one_csv_test.go +++ b/pkg/executor/test/loadremotetest/one_csv_test.go @@ -85,6 +85,47 @@ func (s *mockGCSSuite) TestLoadCSV() { s.tk.MustContainErrMsg(sql, "Don't support load data from tidb-server's disk. Or if you want to load local data via client, the path of INFILE '/etc/passwd' needs to specify the clause of LOCAL first") } +func (s *mockGCSSuite) TestLoadCsvInTransaction() { + s.tk.MustExec("DROP DATABASE IF EXISTS load_csv;") + s.tk.MustExec("CREATE DATABASE load_csv;") + s.tk.MustExec("CREATE TABLE load_csv.t (i INT, s varchar(32));") + + s.server.CreateObject( + fakestorage.Object{ + ObjectAttrs: fakestorage.ObjectAttrs{ + BucketName: "test-load-csv", + Name: "data.csv", + }, + Content: []byte("100, test100\n101, hello\n102, 😄😄😄😄😄\n104, bye"), + }, + ) + + s.tk.MustExec("begin pessimistic") + sql := fmt.Sprintf( + `LOAD DATA INFILE 'gs://test-load-csv/data.csv?endpoint=%s' INTO TABLE load_csv.t `+ + "FIELDS TERMINATED BY ','", + gcsEndpoint, + ) + // test: load data stmt doesn't commit it + s.tk.MustExec("insert into load_csv.t values (1, 'a')") + s.tk.MustExec(sql) + s.tk.MustQuery("select i from load_csv.t order by i").Check( + testkit.Rows( + "1", "100", "101", + "102", "104", + ), + ) + // load data can be rolled back + s.tk.MustExec("rollback") + s.tk.MustQuery("select * from load_csv.t").Check(testkit.Rows()) + + // load data commit + s.tk.MustExec("begin pessimistic") + s.tk.MustExec(sql) + s.tk.MustExec("commit") + s.tk.MustQuery("select i from load_csv.t").Check(testkit.Rows("100", "101", "102", "104")) +} + func (s *mockGCSSuite) TestIgnoreNLines() { s.tk.MustExec("DROP DATABASE IF EXISTS load_csv;") s.tk.MustExec("CREATE DATABASE load_csv;") diff --git a/pkg/executor/test/writetest/write_test.go b/pkg/executor/test/writetest/write_test.go index abfd491894dc5..d57b8b0404a2e 100644 --- a/pkg/executor/test/writetest/write_test.go +++ b/pkg/executor/test/writetest/write_test.go @@ -18,6 +18,7 @@ import ( "context" "errors" "fmt" + "io" "testing" "github.com/pingcap/tidb/br/pkg/lightning/mydump" @@ -165,25 +166,26 @@ type testCase struct { func checkCases( tests []testCase, - ld *executor.LoadDataWorker, + loadSQL string, t *testing.T, tk *testkit.TestKit, ctx sessionctx.Context, selectSQL, deleteSQL string, ) { for _, tt := range tests { - parser, err := mydump.NewCSVParser( - context.Background(), - ld.GetController().GenerateCSVConfig(), - mydump.NewStringReader(string(tt.data)), - 1, - nil, - false, - nil) - require.NoError(t, err) + var reader io.ReadCloser = mydump.NewStringReader(string(tt.data)) + var readerBuilder executor.LoadDataReaderBuilder = func(_ string) ( + r io.ReadCloser, err error, + ) { + return reader, nil + } - err = ld.TestLoadLocal(parser) - require.NoError(t, err) + ctx.SetValue(executor.LoadDataReaderBuilderKey, readerBuilder) + tk.MustExec(loadSQL) + warnings := tk.Session().GetSessionVars().StmtCtx.GetWarnings() + for _, w := range warnings { + fmt.Printf("warnnig: %#v\n", w.Err.Error()) + } require.Equal(t, tt.expectedMsg, tk.Session().LastMessage(), tt.expected) tk.MustQuery(selectSQL).Check(testkit.RowsWithSep("|", tt.expected...)) tk.MustExec(deleteSQL) @@ -196,12 +198,8 @@ func TestLoadDataMissingColumn(t *testing.T) { tk.MustExec("use test") createSQL := `create table load_data_missing (id int, t timestamp not null)` tk.MustExec(createSQL) - tk.MustExec("load data local infile '/tmp/nonexistence.csv' ignore into table load_data_missing") + loadSQL := "load data local infile '/tmp/nonexistence.csv' ignore into table load_data_missing" ctx := tk.Session().(sessionctx.Context) - ld, ok := ctx.Value(executor.LoadDataVarKey).(*executor.LoadDataWorker) - require.True(t, ok) - defer ctx.SetValue(executor.LoadDataVarKey, nil) - require.NotNil(t, ld) deleteSQL := "delete from load_data_missing" selectSQL := "select id, hour(t), minute(t) from load_data_missing;" @@ -213,7 +211,7 @@ func TestLoadDataMissingColumn(t *testing.T) { {[]byte(""), nil, "Records: 0 Deleted: 0 Skipped: 0 Warnings: 0"}, {[]byte("12\n"), []string{fmt.Sprintf("12|%v|%v", timeHour, timeMinute)}, "Records: 1 Deleted: 0 Skipped: 0 Warnings: 1"}, } - checkCases(tests, ld, t, tk, ctx, selectSQL, deleteSQL) + checkCases(tests, loadSQL, t, tk, ctx, selectSQL, deleteSQL) tk.MustExec("alter table load_data_missing add column t2 timestamp null") curTime = types.CurrentTime(mysql.TypeTimestamp) @@ -223,7 +221,7 @@ func TestLoadDataMissingColumn(t *testing.T) { tests = []testCase{ {[]byte("12\n"), []string{fmt.Sprintf("12|%v|%v|", timeHour, timeMinute)}, "Records: 1 Deleted: 0 Skipped: 0 Warnings: 1"}, } - checkCases(tests, ld, t, tk, ctx, selectSQL, deleteSQL) + checkCases(tests, loadSQL, t, tk, ctx, selectSQL, deleteSQL) } func TestIssue18681(t *testing.T) { @@ -233,12 +231,8 @@ func TestIssue18681(t *testing.T) { createSQL := `drop table if exists load_data_test; create table load_data_test (a bit(1),b bit(1),c bit(1),d bit(1));` tk.MustExec(createSQL) - tk.MustExec("load data local infile '/tmp/nonexistence.csv' ignore into table load_data_test") + loadSQL := "load data local infile '/tmp/nonexistence.csv' ignore into table load_data_test" ctx := tk.Session().(sessionctx.Context) - ld, ok := ctx.Value(executor.LoadDataVarKey).(*executor.LoadDataWorker) - require.True(t, ok) - defer ctx.SetValue(executor.LoadDataVarKey, nil) - require.NotNil(t, ld) deleteSQL := "delete from load_data_test" selectSQL := "select bin(a), bin(b), bin(c), bin(d) from load_data_test;" @@ -254,7 +248,7 @@ func TestIssue18681(t *testing.T) { tests := []testCase{ {[]byte("true\tfalse\t0\t1\n"), []string{"1|0|0|1"}, "Records: 1 Deleted: 0 Skipped: 0 Warnings: 0"}, } - checkCases(tests, ld, t, tk, ctx, selectSQL, deleteSQL) + checkCases(tests, loadSQL, t, tk, ctx, selectSQL, deleteSQL) require.Equal(t, uint16(0), sc.WarningCount()) } @@ -268,13 +262,12 @@ func TestIssue34358(t *testing.T) { tk.MustExec("drop table if exists load_data_test") tk.MustExec("create table load_data_test (a varchar(10), b varchar(10))") - tk.MustExec("load data local infile '/tmp/nonexistence.csv' into table load_data_test ( @v1, @v2 ) set a = @v1, b = @v2") - ld, ok := ctx.Value(executor.LoadDataVarKey).(*executor.LoadDataWorker) - require.True(t, ok) - require.NotNil(t, ld) + loadSQL := "load data local infile '/tmp/nonexistence.csv' into table load_data_test ( @v1, " + + "@v2 ) set a = @v1, b = @v2" checkCases([]testCase{ {[]byte("\\N\n"), []string{"|"}, "Records: 1 Deleted: 0 Skipped: 0 Warnings: 1"}, - }, ld, t, tk, ctx, "select * from load_data_test", "delete from load_data_test") + }, loadSQL, t, tk, ctx, "select * from load_data_test", "delete from load_data_test", + ) } func TestLatch(t *testing.T) { diff --git a/pkg/privilege/privileges/privileges_test.go b/pkg/privilege/privileges/privileges_test.go index a8c39ceda572a..043ae41ca52fe 100644 --- a/pkg/privilege/privileges/privileges_test.go +++ b/pkg/privilege/privileges/privileges_test.go @@ -1003,7 +1003,8 @@ func TestLoadDataPrivilege(t *testing.T) { require.NoError(t, tk.Session().Auth(&auth.UserIdentity{Username: "root", Hostname: "localhost"}, nil, nil, nil)) tk.MustExec(`GRANT INSERT on *.* to 'test_load'@'localhost'`) require.NoError(t, tk.Session().Auth(&auth.UserIdentity{Username: "test_load", Hostname: "localhost"}, nil, nil, nil)) - tk.MustExec("LOAD DATA LOCAL INFILE '/tmp/load_data_priv.csv' INTO TABLE t_load") + err = tk.ExecToErr("LOAD DATA LOCAL INFILE '/tmp/load_data_priv.csv' INTO TABLE t_load") + require.ErrorContains(t, err, "reader is nil") require.NoError(t, tk.Session().Auth(&auth.UserIdentity{Username: "root", Hostname: "localhost"}, nil, nil, nil)) tk.MustExec(`GRANT INSERT on *.* to 'test_load'@'localhost'`) diff --git a/pkg/server/conn.go b/pkg/server/conn.go index 3bdc25e4d9722..43efc6c908974 100644 --- a/pkg/server/conn.go +++ b/pkg/server/conn.go @@ -1552,93 +1552,6 @@ func (cc *clientConn) writeReq(ctx context.Context, filePath string) error { return cc.flush(ctx) } -// handleLoadData does the additional work after processing the 'load data' query. -// It sends client a file path, then reads the file content from client, inserts data into database. -func (cc *clientConn) handleLoadData(ctx context.Context, loadDataWorker *executor.LoadDataWorker) error { - // If the server handles the load data request, the client has to set the ClientLocalFiles capability. - if cc.capability&mysql.ClientLocalFiles == 0 { - return servererr.ErrNotAllowedCommand - } - if loadDataWorker == nil { - return errors.New("load data info is empty") - } - infile := loadDataWorker.GetInfilePath() - err := cc.writeReq(ctx, infile) - if err != nil { - return err - } - - var ( - // use Pipe to convert cc.readPacket to io.Reader - r, w = io.Pipe() - drained bool - wg sync.WaitGroup - ) - wg.Add(1) - go func() { - defer wg.Done() - //nolint: errcheck - defer w.Close() - - var ( - data []byte - err2 error - ) - for { - if len(data) == 0 { - data, err2 = cc.readPacket() - if err2 != nil { - w.CloseWithError(err2) - return - } - // https://dev.mysql.com/doc/dev/mysql-server/latest/page_protocol_com_query_response_local_infile_request.html - if len(data) == 0 { - drained = true - return - } - } - - n, err3 := w.Write(data) - if err3 != nil { - logutil.Logger(ctx).Error("write data meet error", zap.Error(err3)) - return - } - data = data[n:] - } - }() - - ctx = kv.WithInternalSourceType(ctx, kv.InternalLoadData) - err = loadDataWorker.LoadLocal(ctx, r) - _ = r.Close() - wg.Wait() - - if err != nil { - if !drained { - logutil.Logger(ctx).Info("not drained yet, try reading left data from client connection") - } - // drain the data from client conn util empty packet received, otherwise the connection will be reset - // https://dev.mysql.com/doc/dev/mysql-server/latest/page_protocol_com_query_response_local_infile_request.html - for !drained { - // check kill flag again, let the draining loop could quit if empty packet could not be received - if atomic.CompareAndSwapUint32(&loadDataWorker.UserSctx.GetSessionVars().SQLKiller.Signal, 1, 0) { - logutil.Logger(ctx).Warn("receiving kill, stop draining data, connection may be reset") - return exeerrors.ErrQueryInterrupted - } - curData, err1 := cc.readPacket() - if err1 != nil { - logutil.Logger(ctx).Error("drain reading left data encounter errors", zap.Error(err1)) - break - } - if len(curData) == 0 { - drained = true - logutil.Logger(ctx).Info("draining finished for error", zap.Error(err)) - break - } - } - } - return err -} - // getDataFromPath gets file contents from file path. func (cc *clientConn) getDataFromPath(ctx context.Context, path string) ([]byte, error) { err := cc.writeReq(ctx, path) @@ -2036,12 +1949,28 @@ func (cc *clientConn) prefetchPointPlanKeys(ctx context.Context, stmts []ast.Stm // The first return value indicates whether the call of handleStmt has no side effect and can be retried. // Currently, the first return value is used to fall back to TiKV when TiFlash is down. -func (cc *clientConn) handleStmt(ctx context.Context, stmt ast.StmtNode, warns []stmtctx.SQLWarn, lastStmt bool) (bool, error) { +func (cc *clientConn) handleStmt( + ctx context.Context, stmt ast.StmtNode, + warns []stmtctx.SQLWarn, lastStmt bool, +) (bool, error) { ctx = context.WithValue(ctx, execdetails.StmtExecDetailKey, &execdetails.StmtExecDetails{}) ctx = context.WithValue(ctx, util.ExecDetailsKey, &util.ExecDetails{}) ctx = context.WithValue(ctx, util.RUDetailsCtxKey, util.NewRUDetails()) reg := trace.StartRegion(ctx, "ExecuteStmt") cc.audit(plugin.Starting) + + // if stmt is load data stmt, store the channel that reads from the conn + // into the ctx for executor to use + if s, ok := stmt.(*ast.LoadDataStmt); ok { + if s.FileLocRef == ast.FileLocClient { + err := cc.preprocessLoadDataLocal(ctx) + defer cc.postprocessLoadDataLocal() + if err != nil { + return false, err + } + } + } + rs, err := cc.ctx.ExecuteStmt(ctx, stmt) reg.End() // - If rs is not nil, the statement tracker detachment from session tracker @@ -2051,6 +1980,7 @@ func (cc *clientConn) handleStmt(ctx context.Context, stmt ast.StmtNode, warns [ if rs != nil { defer rs.Close() } + if err != nil { // If error is returned during the planner phase or the executor.Open // phase, the rs will be nil, and StmtCtx.MemTracker StmtCtx.DiskTracker @@ -2088,18 +2018,93 @@ func (cc *clientConn) handleStmt(ctx context.Context, stmt ast.StmtNode, warns [ return false, err } -func (cc *clientConn) handleFileTransInConn(ctx context.Context, status uint16) (bool, error) { - handled := false - loadDataInfo := cc.ctx.Value(executor.LoadDataVarKey) - if loadDataInfo != nil { - handled = true - defer cc.ctx.SetValue(executor.LoadDataVarKey, nil) - //nolint:forcetypeassert - if err := cc.handleLoadData(ctx, loadDataInfo.(*executor.LoadDataWorker)); err != nil { - return handled, err +// Preprocess LOAD DATA. Load data from a local file requires reading from the connection. +// The function pass a builder to build the connection reader to the context, +// which will be used in LoadDataExec. +func (cc *clientConn) preprocessLoadDataLocal(ctx context.Context) error { + if cc.capability&mysql.ClientLocalFiles == 0 { + return servererr.ErrNotAllowedCommand + } + + var readerBuilder executor.LoadDataReaderBuilder = func(filepath string) ( + io.ReadCloser, error, + ) { + err := cc.writeReq(ctx, filepath) + if err != nil { + return nil, err } + + drained := false + r, w := io.Pipe() + + go func() { + var errOccurred error + + defer func() { + if errOccurred != nil { + // Continue reading packets to drain the connection + for !drained { + data, err := cc.readPacket() + if err != nil { + logutil.Logger(ctx).Error( + "drain connection failed in load data", + zap.Error(err), + ) + break + } + if len(data) == 0 { + drained = true + } + } + } + err := w.CloseWithError(errOccurred) + if err != nil { + logutil.Logger(ctx).Error( + "close pipe failed in `load data`", + zap.Error(err), + ) + } + }() + + for { + data, err := cc.readPacket() + if err != nil { + errOccurred = err + return + } + + if len(data) == 0 { + drained = true + return + } + + // Write all content in `data` + for len(data) > 0 { + n, err := w.Write(data) + if err != nil { + errOccurred = err + return + } + data = data[n:] + } + } + }() + + return r, nil } + cc.ctx.SetValue(executor.LoadDataReaderBuilderKey, readerBuilder) + + return nil +} + +func (cc *clientConn) postprocessLoadDataLocal() { + cc.ctx.ClearValue(executor.LoadDataReaderBuilderKey) +} + +func (cc *clientConn) handleFileTransInConn(ctx context.Context, status uint16) (bool, error) { + handled := false + loadStats := cc.ctx.Value(executor.LoadStatsVarKey) if loadStats != nil { handled = true diff --git a/pkg/server/internal/testserverclient/server_client.go b/pkg/server/internal/testserverclient/server_client.go index a05b24b2b8ccb..0222217ef025e 100644 --- a/pkg/server/internal/testserverclient/server_client.go +++ b/pkg/server/internal/testserverclient/server_client.go @@ -28,6 +28,7 @@ import ( "regexp" "strconv" "strings" + "sync" "testing" "time" @@ -1008,6 +1009,204 @@ func columnsAsExpected(t *testing.T, columns []*sql.NullString, expected []strin } } +func (cli *TestServerClient) RunTestLoadDataInTransaction(t *testing.T) { + fp, err := os.CreateTemp("", "load_data_test.csv") + require.NoError(t, err) + path := fp.Name() + + require.NotNil(t, fp) + defer func() { + err = fp.Close() + require.NoError(t, err) + err = os.Remove(path) + require.NoError(t, err) + }() + + _, err = fp.WriteString("1") + require.NoError(t, err) + + // load file in transaction can be rolled back + cli.RunTestsOnNewDB( + t, func(config *mysql.Config) { + config.AllowAllFiles = true + config.Params["sql_mode"] = "''" + }, "LoadDataInTransaction", func(dbt *testkit.DBTestKit) { + dbt.MustExec("create table t (a int)") + txn, err := dbt.GetDB().Begin() + require.NoError(t, err) + txn.Exec("insert into t values (100)") // `load data` doesn't commit current txn + _, err = txn.Exec(fmt.Sprintf("load data local infile %q into table t", path)) + require.NoError(t, err) + rows, err := txn.Query("select * from t") + require.NoError(t, err) + cli.CheckRows(t, rows, "100\n1") + err = txn.Rollback() + require.NoError(t, err) + rows = dbt.MustQuery("select * from t") + cli.CheckRows(t, rows) + }, + ) + + // load file in transaction doesn't commit until the transaction is committed + cli.RunTestsOnNewDB( + t, func(config *mysql.Config) { + config.AllowAllFiles = true + config.Params["sql_mode"] = "''" + }, "LoadDataInTransaction", func(dbt *testkit.DBTestKit) { + dbt.MustExec("create table t (a int)") + txn, err := dbt.GetDB().Begin() + require.NoError(t, err) + _, err = txn.Exec(fmt.Sprintf("load data local infile %q into table t", path)) + require.NoError(t, err) + rows, err := txn.Query("select * from t") + require.NoError(t, err) + cli.CheckRows(t, rows, "1") + err = txn.Commit() + require.NoError(t, err) + rows = dbt.MustQuery("select * from t") + cli.CheckRows(t, rows, "1") + }, + ) + + // load file in auto commit mode should succeed + cli.RunTestsOnNewDB( + t, func(config *mysql.Config) { + config.AllowAllFiles = true + config.Params["sql_mode"] = "''" + }, "LoadDataInAutoCommit", func(dbt *testkit.DBTestKit) { + dbt.MustExec("create table t (a int)") + dbt.MustExec(fmt.Sprintf("load data local infile %q into table t", path)) + txn, err := dbt.GetDB().Begin() + require.NoError(t, err) + rows, _ := txn.Query("select * from t") + cli.CheckRows(t, rows, "1") + }, + ) + + // load file in a pessimistic transaction, + // should acquire locks when after its execution and before it commits. + // The lock should be observed by another transaction that is attempting to acquire the same + // lock. + dbName := "LoadDataInPessimisticTransaction" + cli.RunTestsOnNewDB( + t, func(config *mysql.Config) { + config.AllowAllFiles = true + config.Params["sql_mode"] = "''" + }, dbName, func(dbt *testkit.DBTestKit) { + dbt.MustExec("set @@global.tidb_txn_mode = 'pessimistic'") + dbt.MustExec("create table t (a int primary key)") + txn, err := dbt.GetDB().Begin() + require.NoError(t, err) + _, err = txn.Exec(fmt.Sprintf("USE `%s`;", dbName)) + require.NoError(t, err) + _, err = txn.Exec(fmt.Sprintf("load data local infile %q into table t", path)) + require.NoError(t, err) + rows, err := txn.Query("select * from t") + require.NoError(t, err) + cli.CheckRows(t, rows, "1") + + var wg sync.WaitGroup + wg.Add(1) + txn2Locked := make(chan struct{}, 1) + failed := make(chan struct{}, 1) + go func() { + time.Sleep(2 * time.Second) + select { + case <-txn2Locked: + failed <- struct{}{} + default: + } + + err2 := txn.Commit() + require.NoError(t, err2) + wg.Done() + }() + txn2, err := dbt.GetDB().Begin() + require.NoError(t, err) + _, err = txn2.Exec(fmt.Sprintf("USE `%s`;", dbName)) + require.NoError(t, err) + _, err = txn2.Exec("select * from t where a = 1 for update") + require.NoError(t, err) + txn2Locked <- struct{}{} + wg.Wait() + txn2.Rollback() + select { + case <-failed: + require.Fail(t, "txn2 should not be able to acquire the lock") + default: + } + + require.NoError(t, err) + rows = dbt.MustQuery("select * from t") + cli.CheckRows(t, rows, "1") + }, + ) + + dbName = "LoadDataInExplicitTransaction" + cli.RunTestsOnNewDB( + t, func(config *mysql.Config) { + config.AllowAllFiles = true + config.Params["sql_mode"] = "''" + }, dbName, func(dbt *testkit.DBTestKit) { + // in optimistic txn, one should not block another + dbt.MustExec("set @@global.tidb_txn_mode = 'optimistic'") + dbt.MustExec("create table t (a int primary key)") + txn1, err := dbt.GetDB().Begin() + require.NoError(t, err) + txn2, err := dbt.GetDB().Begin() + require.NoError(t, err) + _, err = txn1.Exec(fmt.Sprintf("USE `%s`;", dbName)) + require.NoError(t, err) + _, err = txn2.Exec(fmt.Sprintf("USE `%s`;", dbName)) + require.NoError(t, err) + _, err = txn1.Exec(fmt.Sprintf("load data local infile %q into table t", path)) + require.NoError(t, err) + _, err = txn2.Exec(fmt.Sprintf("load data local infile %q into table t", path)) + require.NoError(t, err) + err = txn1.Commit() + require.NoError(t, err) + err = txn2.Commit() + require.ErrorContains(t, err, "Write conflict") + rows := dbt.MustQuery("select * from t") + cli.CheckRows(t, rows, "1") + }, + ) + + cli.RunTestsOnNewDB( + t, func(config *mysql.Config) { + config.AllowAllFiles = true + config.Params["sql_mode"] = "''" + }, "LoadDataFromServerFile", func(dbt *testkit.DBTestKit) { + dbt.MustExec("create table t (a int)") + _, err = dbt.GetDB().Exec(fmt.Sprintf("load data infile %q into table t", path)) + require.ErrorContains(t, err, "Don't support load data from tidb-server's disk.") + }, + ) + + // The test is intended to test if the load data statement correctly cleans up its + // resources after execution, and does not affect following statements. + // For example, the 1st load data builds the reader and finishes. + // The 2nd load data should not be able to access the reader, especially when it should fail + cli.RunTestsOnNewDB( + t, func(config *mysql.Config) { + config.AllowAllFiles = true + config.Params["sql_mode"] = "''" + }, "LoadDataCleanup", func(dbt *testkit.DBTestKit) { + dbt.MustExec("create table t (a int)") + txn, err := dbt.GetDB().Begin() + require.NoError(t, err) + _, err = txn.Exec(fmt.Sprintf("load data local infile %q into table t", path)) + require.NoError(t, err) + _, err = txn.Exec("load data local infile '/tmp/does_not_exist' into table t") + require.ErrorContains(t, err, "no such file or directory") + err = txn.Commit() + require.NoError(t, err) + rows := dbt.MustQuery("select * from t") + cli.CheckRows(t, rows, "1") + }, + ) +} + func (cli *TestServerClient) RunTestLoadData(t *testing.T, server *server.Server) { fp, err := os.CreateTemp("", "load_data_test.csv") require.NoError(t, err) diff --git a/pkg/server/tests/tidb_serial_test.go b/pkg/server/tests/tidb_serial_test.go index 8f7e263100528..132703e96fd18 100644 --- a/pkg/server/tests/tidb_serial_test.go +++ b/pkg/server/tests/tidb_serial_test.go @@ -70,6 +70,12 @@ func TestLoadData1(t *testing.T) { ts.RunTestLoadDataForSlowLog(t) } +func TestLoadDataInTransaction(t *testing.T) { + ts := createTidbTestSuite(t) + + ts.RunTestLoadDataInTransaction(t) +} + func TestConfigDefaultValue(t *testing.T) { ts := createTidbTestSuite(t)