From 6e593be4225358ad0bba0a6df8a3220fbb1f3c35 Mon Sep 17 00:00:00 2001 From: you06 Date: Fri, 4 Aug 2023 20:44:39 +0800 Subject: [PATCH] executor: fix load data assertion failure (#43858) (#45825) --- server/server_test.go | 70 ++++++++++++++++++++++++++++++++++++++++++ server/tidb_test.go | 5 +++ table/tables/tables.go | 3 +- 3 files changed, 77 insertions(+), 1 deletion(-) diff --git a/server/server_test.go b/server/server_test.go index 623a4d3313628..b585e8b17bcaf 100644 --- a/server/server_test.go +++ b/server/server_test.go @@ -2381,3 +2381,73 @@ func (cli *testServerClient) runTestInfoschemaClientErrors(t *testing.T) { } }) } + +func (cli *testServerClient) runTestLoadDataReplace(t *testing.T) { + fp1, err := os.CreateTemp("", "a.dat") + require.NoError(t, err) + require.NotNil(t, fp1) + path1 := fp1.Name() + fp2, err := os.CreateTemp("", "b.dat") + require.NoError(t, err) + require.NotNil(t, fp2) + path2 := fp2.Name() + defer func() { + err = fp1.Close() + require.NoError(t, err) + err = os.Remove(path1) + require.NoError(t, err) + + err = fp2.Close() + require.NoError(t, err) + err = os.Remove(path2) + require.NoError(t, err) + }() + + _, err = fp1.WriteString( + "1,abc\n" + + "2,cdef\n" + + "3,asdf\n") + require.NoError(t, err) + _, err = fp2.WriteString( + "1,AAA\n" + + "2,BBB\n" + + "3,asdf\n" + + "4,444\n") + require.NoError(t, err) + + expects := []struct { + col1 int64 + col2 string + }{ + {1, "AAA"}, + {2, "BBB"}, + {3, "asdf"}, + {4, "444"}, + } + + cli.runTestsOnNewDB(t, func(config *mysql.Config) { + config.AllowAllFiles = true + config.Params["sql_mode"] = "''" + }, "LoadData", func(dbt *testkit.DBTestKit) { + dbt.MustExec("create table t1(id int, name varchar(20), primary key(id) clustered);") + _, err = dbt.GetDB().Exec(fmt.Sprintf(`load data local infile '%s' replace into table t1 fields terminated by ',' enclosed by '' (id,name)`, path1)) + require.NoError(t, err) + _, err = dbt.GetDB().Exec(fmt.Sprintf(`load data local infile '%s' replace into table t1 fields terminated by ',' enclosed by '' (id,name)`, path2)) + require.NoError(t, err) + var ( + a sql.NullInt64 + b sql.NullString + ) + rows := dbt.MustQuery("select * from t1 order by id asc") + for _, expect := range expects { + require.Truef(t, rows.Next(), "unexpected data") + err = rows.Scan(&a, &b) + require.NoError(t, err) + require.Equal(t, expect.col1, a.Int64) + require.Equal(t, expect.col2, b.String) + err = rows.Scan(&a, &b) + require.NoError(t, err) + } + require.Falsef(t, rows.Next(), "expect end") + }) +} diff --git a/server/tidb_test.go b/server/tidb_test.go index a9cf74347a56f..30e4ea814927f 100644 --- a/server/tidb_test.go +++ b/server/tidb_test.go @@ -3210,3 +3210,8 @@ func TestProxyProtocolWithIpNoFallbackable(t *testing.T) { require.NotNil(t, err) db.Close() } + +func TestLoadData(t *testing.T) { + ts := createTidbTestSuite(t) + ts.runTestLoadDataReplace(t) +} diff --git a/table/tables/tables.go b/table/tables/tables.go index da20b1647fbd8..a18896ce07f94 100644 --- a/table/tables/tables.go +++ b/table/tables/tables.go @@ -882,7 +882,8 @@ func (t *TableCommon) AddRecord(sctx sessionctx.Context, r []types.Datum, opts . } } }) - if setPresume && !txn.IsPessimistic() { + // batch-check guarantees the existence of key itself. + if setPresume && !txn.IsPessimistic() || sctx.GetSessionVars().StmtCtx.BatchCheck { err = txn.SetAssertion(key, kv.SetAssertUnknown) } else { err = txn.SetAssertion(key, kv.SetAssertNotExist)