diff --git a/planner/core/enforce_mpp_test.go b/planner/core/enforce_mpp_test.go index 4ee16565a9c8a..8cc599d075d28 100644 --- a/planner/core/enforce_mpp_test.go +++ b/planner/core/enforce_mpp_test.go @@ -16,49 +16,22 @@ package core_test import ( "strings" + "testing" - . "github.com/pingcap/check" "github.com/pingcap/tidb/domain" - "github.com/pingcap/tidb/kv" "github.com/pingcap/tidb/parser/model" + plannercore "github.com/pingcap/tidb/planner/core" "github.com/pingcap/tidb/sessionctx/stmtctx" + "github.com/pingcap/tidb/testkit" + "github.com/pingcap/tidb/testkit/testdata" "github.com/pingcap/tidb/util/collate" - "github.com/pingcap/tidb/util/testkit" - "github.com/pingcap/tidb/util/testutil" + "github.com/stretchr/testify/require" ) -var _ = SerialSuites(&testEnforceMPPSuite{}) - -type testEnforceMPPSuite struct { - testData testutil.TestData - store kv.Storage - dom *domain.Domain -} - -func (s *testEnforceMPPSuite) SetUpSuite(c *C) { - var err error - s.testData, err = testutil.LoadTestSuiteData("testdata", "enforce_mpp_suite") - c.Assert(err, IsNil) -} - -func (s *testEnforceMPPSuite) TearDownSuite(c *C) { - c.Assert(s.testData.GenerateOutputIfNeeded(), IsNil) -} - -func (s *testEnforceMPPSuite) SetUpTest(c *C) { - var err error - s.store, s.dom, err = newStoreWithBootstrap() - c.Assert(err, IsNil) -} - -func (s *testEnforceMPPSuite) TearDownTest(c *C) { - s.dom.Close() - err := s.store.Close() - c.Assert(err, IsNil) -} - -func (s *testEnforceMPPSuite) TestSetVariables(c *C) { - tk := testkit.NewTestKit(c, s.store) +func TestSetVariables(t *testing.T) { + store, clean := testkit.CreateMockStore(t) + defer clean() + tk := testkit.NewTestKit(t, store) // test value limit of tidb_opt_tiflash_concurrency_factor tk.MustExec("set @@tidb_opt_tiflash_concurrency_factor = 0") @@ -67,18 +40,17 @@ func (s *testEnforceMPPSuite) TestSetVariables(c *C) { // test set tidb_enforce_mpp when tidb_allow_mpp=false; err := tk.ExecToErr("set @@tidb_allow_mpp = 0; set @@tidb_enforce_mpp = 1;") - c.Assert(err, NotNil) - c.Assert(err.Error(), Equals, `[variable:1231]Variable 'tidb_enforce_mpp' can't be set to the value of '1' but tidb_allow_mpp is 0, please activate tidb_allow_mpp at first.'`) - + require.EqualError(t, err, `[variable:1231]Variable 'tidb_enforce_mpp' can't be set to the value of '1' but tidb_allow_mpp is 0, please activate tidb_allow_mpp at first.'`) err = tk.ExecToErr("set @@tidb_allow_mpp = 1; set @@tidb_enforce_mpp = 1;") - c.Assert(err, IsNil) - + require.NoError(t, err) err = tk.ExecToErr("set @@tidb_allow_mpp = 0;") - c.Assert(err, IsNil) + require.NoError(t, err) } -func (s *testEnforceMPPSuite) TestEnforceMPP(c *C) { - tk := testkit.NewTestKit(c, s.store) +func TestEnforceMPP(t *testing.T) { + store, clean := testkit.CreateMockStore(t) + defer clean() + tk := testkit.NewTestKit(t, store) // test query tk.MustExec("use test") @@ -87,10 +59,10 @@ func (s *testEnforceMPPSuite) TestEnforceMPP(c *C) { tk.MustExec("create index idx on t(a)") // Create virtual tiflash replica info. - dom := domain.GetDomain(tk.Se) + dom := domain.GetDomain(tk.Session()) is := dom.InfoSchema() db, exists := is.SchemaByName(model.NewCIStr("test")) - c.Assert(exists, IsTrue) + require.True(t, exists) for _, tblInfo := range db.Tables { if tblInfo.Name.L == "t" { tblInfo.TiFlashReplica = &model.TiFlashReplicaInfo{ @@ -106,7 +78,8 @@ func (s *testEnforceMPPSuite) TestEnforceMPP(c *C) { Plan []string Warn []string } - s.testData.GetTestCases(c, &input, &output) + enforceMPPSuiteData := plannercore.GetEnforceMPPSuiteData() + enforceMPPSuiteData.GetTestCases(t, &input, &output) filterWarnings := func(originalWarnings []stmtctx.SQLWarn) []stmtctx.SQLWarn { warnings := make([]stmtctx.SQLWarn, 0, 4) for _, warning := range originalWarnings { @@ -118,27 +91,29 @@ func (s *testEnforceMPPSuite) TestEnforceMPP(c *C) { return warnings } for i, tt := range input { - s.testData.OnRecord(func() { + testdata.OnRecord(func() { output[i].SQL = tt }) if strings.HasPrefix(tt, "set") { tk.MustExec(tt) continue } - s.testData.OnRecord(func() { + testdata.OnRecord(func() { output[i].SQL = tt - output[i].Plan = s.testData.ConvertRowsToStrings(tk.MustQuery(tt).Rows()) - output[i].Warn = s.testData.ConvertSQLWarnToStrings(filterWarnings(tk.Se.GetSessionVars().StmtCtx.GetWarnings())) + output[i].Plan = testdata.ConvertRowsToStrings(tk.MustQuery(tt).Rows()) + output[i].Warn = testdata.ConvertSQLWarnToStrings(filterWarnings(tk.Session().GetSessionVars().StmtCtx.GetWarnings())) }) res := tk.MustQuery(tt) res.Check(testkit.Rows(output[i].Plan...)) - c.Assert(s.testData.ConvertSQLWarnToStrings(filterWarnings(tk.Se.GetSessionVars().StmtCtx.GetWarnings())), DeepEquals, output[i].Warn) + require.Equal(t, output[i].Warn, testdata.ConvertSQLWarnToStrings(filterWarnings(tk.Session().GetSessionVars().StmtCtx.GetWarnings()))) } } // general cases. -func (s *testEnforceMPPSuite) TestEnforceMPPWarning1(c *C) { - tk := testkit.NewTestKit(c, s.store) +func TestEnforceMPPWarning1(t *testing.T) { + store, clean := testkit.CreateMockStore(t) + defer clean() + tk := testkit.NewTestKit(t, store) // test query tk.MustExec("use test") @@ -152,9 +127,10 @@ func (s *testEnforceMPPSuite) TestEnforceMPPWarning1(c *C) { Plan []string Warn []string } - s.testData.GetTestCases(c, &input, &output) + enforceMPPSuiteData := plannercore.GetEnforceMPPSuiteData() + enforceMPPSuiteData.GetTestCases(t, &input, &output) for i, tt := range input { - s.testData.OnRecord(func() { + testdata.OnRecord(func() { output[i].SQL = tt }) if strings.HasPrefix(tt, "set") { @@ -163,10 +139,10 @@ func (s *testEnforceMPPSuite) TestEnforceMPPWarning1(c *C) { } if strings.HasPrefix(tt, "cmd: create-replica") { // Create virtual tiflash replica info. - dom := domain.GetDomain(tk.Se) + dom := domain.GetDomain(tk.Session()) is := dom.InfoSchema() db, exists := is.SchemaByName(model.NewCIStr("test")) - c.Assert(exists, IsTrue) + require.True(t, exists) for _, tblInfo := range db.Tables { if tblInfo.Name.L == "t" { tblInfo.TiFlashReplica = &model.TiFlashReplicaInfo{ @@ -179,10 +155,10 @@ func (s *testEnforceMPPSuite) TestEnforceMPPWarning1(c *C) { } if strings.HasPrefix(tt, "cmd: enable-replica") { // Create virtual tiflash replica info. - dom := domain.GetDomain(tk.Se) + dom := domain.GetDomain(tk.Session()) is := dom.InfoSchema() db, exists := is.SchemaByName(model.NewCIStr("test")) - c.Assert(exists, IsTrue) + require.True(t, exists) for _, tblInfo := range db.Tables { if tblInfo.Name.L == "t" { tblInfo.TiFlashReplica = &model.TiFlashReplicaInfo{ @@ -193,20 +169,22 @@ func (s *testEnforceMPPSuite) TestEnforceMPPWarning1(c *C) { } continue } - s.testData.OnRecord(func() { + testdata.OnRecord(func() { output[i].SQL = tt - output[i].Plan = s.testData.ConvertRowsToStrings(tk.MustQuery(tt).Rows()) - output[i].Warn = s.testData.ConvertSQLWarnToStrings(tk.Se.GetSessionVars().StmtCtx.GetWarnings()) + output[i].Plan = testdata.ConvertRowsToStrings(tk.MustQuery(tt).Rows()) + output[i].Warn = testdata.ConvertSQLWarnToStrings(tk.Session().GetSessionVars().StmtCtx.GetWarnings()) }) res := tk.MustQuery(tt) res.Check(testkit.Rows(output[i].Plan...)) - c.Assert(s.testData.ConvertSQLWarnToStrings(tk.Se.GetSessionVars().StmtCtx.GetWarnings()), DeepEquals, output[i].Warn) + require.Equal(t, output[i].Warn, testdata.ConvertSQLWarnToStrings(tk.Session().GetSessionVars().StmtCtx.GetWarnings())) } } // partition table. -func (s *testEnforceMPPSuite) TestEnforceMPPWarning2(c *C) { - tk := testkit.NewTestKit(c, s.store) +func TestEnforceMPPWarning2(t *testing.T) { + store, clean := testkit.CreateMockStore(t) + defer clean() + tk := testkit.NewTestKit(t, store) // test query tk.MustExec("use test") @@ -214,10 +192,10 @@ func (s *testEnforceMPPSuite) TestEnforceMPPWarning2(c *C) { tk.MustExec("CREATE TABLE t (a int, b char(20)) PARTITION BY HASH(a)") // Create virtual tiflash replica info. - dom := domain.GetDomain(tk.Se) + dom := domain.GetDomain(tk.Session()) is := dom.InfoSchema() db, exists := is.SchemaByName(model.NewCIStr("test")) - c.Assert(exists, IsTrue) + require.True(t, exists) for _, tblInfo := range db.Tables { if tblInfo.Name.L == "t" { tblInfo.TiFlashReplica = &model.TiFlashReplicaInfo{ @@ -233,29 +211,32 @@ func (s *testEnforceMPPSuite) TestEnforceMPPWarning2(c *C) { Plan []string Warn []string } - s.testData.GetTestCases(c, &input, &output) + enforceMPPSuiteData := plannercore.GetEnforceMPPSuiteData() + enforceMPPSuiteData.GetTestCases(t, &input, &output) for i, tt := range input { - s.testData.OnRecord(func() { + testdata.OnRecord(func() { output[i].SQL = tt }) if strings.HasPrefix(tt, "set") { tk.MustExec(tt) continue } - s.testData.OnRecord(func() { + testdata.OnRecord(func() { output[i].SQL = tt - output[i].Plan = s.testData.ConvertRowsToStrings(tk.MustQuery(tt).Rows()) - output[i].Warn = s.testData.ConvertSQLWarnToStrings(tk.Se.GetSessionVars().StmtCtx.GetWarnings()) + output[i].Plan = testdata.ConvertRowsToStrings(tk.MustQuery(tt).Rows()) + output[i].Warn = testdata.ConvertSQLWarnToStrings(tk.Session().GetSessionVars().StmtCtx.GetWarnings()) }) res := tk.MustQuery(tt) res.Check(testkit.Rows(output[i].Plan...)) - c.Assert(s.testData.ConvertSQLWarnToStrings(tk.Se.GetSessionVars().StmtCtx.GetWarnings()), DeepEquals, output[i].Warn) + require.Equal(t, output[i].Warn, testdata.ConvertSQLWarnToStrings(tk.Session().GetSessionVars().StmtCtx.GetWarnings())) } } // new collation. -func (s *testEnforceMPPSuite) TestEnforceMPPWarning3(c *C) { - tk := testkit.NewTestKit(c, s.store) +func TestEnforceMPPWarning3(t *testing.T) { + store, clean := testkit.CreateMockStore(t) + defer clean() + tk := testkit.NewTestKit(t, store) // test query tk.MustExec("use test") @@ -263,10 +244,10 @@ func (s *testEnforceMPPSuite) TestEnforceMPPWarning3(c *C) { tk.MustExec("CREATE TABLE t (a int, b char(20))") // Create virtual tiflash replica info. - dom := domain.GetDomain(tk.Se) + dom := domain.GetDomain(tk.Session()) is := dom.InfoSchema() db, exists := is.SchemaByName(model.NewCIStr("test")) - c.Assert(exists, IsTrue) + require.True(t, exists) for _, tblInfo := range db.Tables { if tblInfo.Name.L == "t" { tblInfo.TiFlashReplica = &model.TiFlashReplicaInfo{ @@ -282,9 +263,10 @@ func (s *testEnforceMPPSuite) TestEnforceMPPWarning3(c *C) { Plan []string Warn []string } - s.testData.GetTestCases(c, &input, &output) + enforceMPPSuiteData := plannercore.GetEnforceMPPSuiteData() + enforceMPPSuiteData.GetTestCases(t, &input, &output) for i, tt := range input { - s.testData.OnRecord(func() { + testdata.OnRecord(func() { output[i].SQL = tt }) if strings.HasPrefix(tt, "set") || strings.HasPrefix(tt, "UPDATE") { @@ -299,21 +281,23 @@ func (s *testEnforceMPPSuite) TestEnforceMPPWarning3(c *C) { collate.SetNewCollationEnabledForTest(false) continue } - s.testData.OnRecord(func() { + testdata.OnRecord(func() { output[i].SQL = tt - output[i].Plan = s.testData.ConvertRowsToStrings(tk.MustQuery(tt).Rows()) - output[i].Warn = s.testData.ConvertSQLWarnToStrings(tk.Se.GetSessionVars().StmtCtx.GetWarnings()) + output[i].Plan = testdata.ConvertRowsToStrings(tk.MustQuery(tt).Rows()) + output[i].Warn = testdata.ConvertSQLWarnToStrings(tk.Session().GetSessionVars().StmtCtx.GetWarnings()) }) res := tk.MustQuery(tt) res.Check(testkit.Rows(output[i].Plan...)) - c.Assert(s.testData.ConvertSQLWarnToStrings(tk.Se.GetSessionVars().StmtCtx.GetWarnings()), DeepEquals, output[i].Warn) + require.Equal(t, output[i].Warn, testdata.ConvertSQLWarnToStrings(tk.Session().GetSessionVars().StmtCtx.GetWarnings())) } collate.SetNewCollationEnabledForTest(true) } // Test enforce mpp warning for joins -func (s *testEnforceMPPSuite) TestEnforceMPPWarning4(c *C) { - tk := testkit.NewTestKit(c, s.store) +func TestEnforceMPPWarning4(t *testing.T) { + store, clean := testkit.CreateMockStore(t) + defer clean() + tk := testkit.NewTestKit(t, store) // test table tk.MustExec("use test") @@ -323,10 +307,10 @@ func (s *testEnforceMPPSuite) TestEnforceMPPWarning4(c *C) { tk.MustExec("CREATE TABLE s(a int primary key)") // Create virtual tiflash replica info. - dom := domain.GetDomain(tk.Se) + dom := domain.GetDomain(tk.Session()) is := dom.InfoSchema() db, exists := is.SchemaByName(model.NewCIStr("test")) - c.Assert(exists, IsTrue) + require.True(t, exists) for _, tblInfo := range db.Tables { if tblInfo.Name.L == "t" || tblInfo.Name.L == "s" { tblInfo.TiFlashReplica = &model.TiFlashReplicaInfo{ @@ -342,22 +326,23 @@ func (s *testEnforceMPPSuite) TestEnforceMPPWarning4(c *C) { Plan []string Warn []string } - s.testData.GetTestCases(c, &input, &output) + enforceMPPSuiteData := plannercore.GetEnforceMPPSuiteData() + enforceMPPSuiteData.GetTestCases(t, &input, &output) for i, tt := range input { - s.testData.OnRecord(func() { + testdata.OnRecord(func() { output[i].SQL = tt }) if strings.HasPrefix(tt, "set") || strings.HasPrefix(tt, "UPDATE") { tk.MustExec(tt) continue } - s.testData.OnRecord(func() { + testdata.OnRecord(func() { output[i].SQL = tt - output[i].Plan = s.testData.ConvertRowsToStrings(tk.MustQuery(tt).Rows()) - output[i].Warn = s.testData.ConvertSQLWarnToStrings(tk.Se.GetSessionVars().StmtCtx.GetWarnings()) + output[i].Plan = testdata.ConvertRowsToStrings(tk.MustQuery(tt).Rows()) + output[i].Warn = testdata.ConvertSQLWarnToStrings(tk.Session().GetSessionVars().StmtCtx.GetWarnings()) }) res := tk.MustQuery(tt) res.Check(testkit.Rows(output[i].Plan...)) - c.Assert(s.testData.ConvertSQLWarnToStrings(tk.Se.GetSessionVars().StmtCtx.GetWarnings()), DeepEquals, output[i].Warn) + require.Equal(t, output[i].Warn, testdata.ConvertSQLWarnToStrings(tk.Session().GetSessionVars().StmtCtx.GetWarnings())) } } diff --git a/planner/core/main_test.go b/planner/core/main_test.go index 926b9ad2d7ee5..af18bc6621a50 100644 --- a/planner/core/main_test.go +++ b/planner/core/main_test.go @@ -24,7 +24,7 @@ import ( "go.uber.org/goleak" ) -var testDataMap = make(testdata.BookKeeper, 5) +var testDataMap = make(testdata.BookKeeper) var indexMergeSuiteData testdata.TestData func TestMain(m *testing.M) { @@ -38,6 +38,7 @@ func TestMain(m *testing.M) { testDataMap.LoadTestSuiteData("testdata", "stats_suite") testDataMap.LoadTestSuiteData("testdata", "ordered_result_mode_suite") testDataMap.LoadTestSuiteData("testdata", "point_get_plan") + testDataMap.LoadTestSuiteData("testdata", "enforce_mpp_suite") indexMergeSuiteData = testDataMap["index_merge_suite"] @@ -73,3 +74,7 @@ func GetOrderedResultModeSuiteData() testdata.TestData { func GetPointGetPlanData() testdata.TestData { return testDataMap["point_get_plan"] } + +func GetEnforceMPPSuiteData() testdata.TestData { + return testDataMap["enforce_mpp_suite"] +} diff --git a/planner/core/point_get_plan_test.go b/planner/core/point_get_plan_test.go index e544abfc47dfc..ba83f6f0d384c 100644 --- a/planner/core/point_get_plan_test.go +++ b/planner/core/point_get_plan_test.go @@ -339,8 +339,8 @@ func TestCBOPointGet(t *testing.T) { Plan []string Res []string } - statsSuiteData := core.GetPointGetPlanData() - statsSuiteData.GetTestCases(t, &input, &output) + pointGetPlanData := core.GetPointGetPlanData() + pointGetPlanData.GetTestCases(t, &input, &output) require.Equal(t, len(input), len(output)) for i, sql := range input { plan := tk.MustQuery("explain format = 'brief' " + sql) @@ -827,8 +827,8 @@ func TestCBOShouldNotUsePointGet(t *testing.T) { Res []string } - statsSuiteData := core.GetPointGetPlanData() - statsSuiteData.GetTestCases(t, &input, &output) + pointGetPlanData := core.GetPointGetPlanData() + pointGetPlanData.GetTestCases(t, &input, &output) require.Equal(t, len(input), len(output)) for i, sql := range input { plan := tk.MustQuery("explain format = 'brief' " + sql) diff --git a/testkit/testdata/testdata.go b/testkit/testdata/testdata.go index 41ecd53cbf9e2..8ab2f343d7b8e 100644 --- a/testkit/testdata/testdata.go +++ b/testkit/testdata/testdata.go @@ -32,6 +32,7 @@ import ( "testing" "github.com/pingcap/errors" + "github.com/pingcap/tidb/sessionctx/stmtctx" "github.com/stretchr/testify/require" ) @@ -124,6 +125,14 @@ func ConvertRowsToStrings(rows [][]interface{}) (rs []string) { return rs } +// ConvertSQLWarnToStrings converts []SQLWarn to []string. +func ConvertSQLWarnToStrings(warns []stmtctx.SQLWarn) (rs []string) { + for _, warn := range warns { + rs = append(rs, fmt.Sprint(warn.Err.Error())) + } + return rs +} + // GetTestCases gets the test cases for a test function. func (td *TestData) GetTestCases(t *testing.T, in interface{}, out interface{}) { // Extract caller's name.