Skip to content

Commit

Permalink
planner/core: migrate test-infra to testify for enforce_mpp_test.go (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
tisonkun authored Feb 19, 2022
1 parent 1d32049 commit 36a1f84
Show file tree
Hide file tree
Showing 4 changed files with 98 additions and 99 deletions.
173 changes: 79 additions & 94 deletions planner/core/enforce_mpp_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,49 +16,22 @@ package core_test

import (
"strings"
"testing"

. "github.com/pingcap/check"
"github.com/pingcap/tidb/domain"
"github.com/pingcap/tidb/kv"
"github.com/pingcap/tidb/parser/model"
plannercore "github.com/pingcap/tidb/planner/core"
"github.com/pingcap/tidb/sessionctx/stmtctx"
"github.com/pingcap/tidb/testkit"
"github.com/pingcap/tidb/testkit/testdata"
"github.com/pingcap/tidb/util/collate"
"github.com/pingcap/tidb/util/testkit"
"github.com/pingcap/tidb/util/testutil"
"github.com/stretchr/testify/require"
)

var _ = SerialSuites(&testEnforceMPPSuite{})

type testEnforceMPPSuite struct {
testData testutil.TestData
store kv.Storage
dom *domain.Domain
}

func (s *testEnforceMPPSuite) SetUpSuite(c *C) {
var err error
s.testData, err = testutil.LoadTestSuiteData("testdata", "enforce_mpp_suite")
c.Assert(err, IsNil)
}

func (s *testEnforceMPPSuite) TearDownSuite(c *C) {
c.Assert(s.testData.GenerateOutputIfNeeded(), IsNil)
}

func (s *testEnforceMPPSuite) SetUpTest(c *C) {
var err error
s.store, s.dom, err = newStoreWithBootstrap()
c.Assert(err, IsNil)
}

func (s *testEnforceMPPSuite) TearDownTest(c *C) {
s.dom.Close()
err := s.store.Close()
c.Assert(err, IsNil)
}

func (s *testEnforceMPPSuite) TestSetVariables(c *C) {
tk := testkit.NewTestKit(c, s.store)
func TestSetVariables(t *testing.T) {
store, clean := testkit.CreateMockStore(t)
defer clean()
tk := testkit.NewTestKit(t, store)

// test value limit of tidb_opt_tiflash_concurrency_factor
tk.MustExec("set @@tidb_opt_tiflash_concurrency_factor = 0")
Expand All @@ -67,18 +40,17 @@ func (s *testEnforceMPPSuite) TestSetVariables(c *C) {

// test set tidb_enforce_mpp when tidb_allow_mpp=false;
err := tk.ExecToErr("set @@tidb_allow_mpp = 0; set @@tidb_enforce_mpp = 1;")
c.Assert(err, NotNil)
c.Assert(err.Error(), Equals, `[variable:1231]Variable 'tidb_enforce_mpp' can't be set to the value of '1' but tidb_allow_mpp is 0, please activate tidb_allow_mpp at first.'`)

require.EqualError(t, err, `[variable:1231]Variable 'tidb_enforce_mpp' can't be set to the value of '1' but tidb_allow_mpp is 0, please activate tidb_allow_mpp at first.'`)
err = tk.ExecToErr("set @@tidb_allow_mpp = 1; set @@tidb_enforce_mpp = 1;")
c.Assert(err, IsNil)

require.NoError(t, err)
err = tk.ExecToErr("set @@tidb_allow_mpp = 0;")
c.Assert(err, IsNil)
require.NoError(t, err)
}

func (s *testEnforceMPPSuite) TestEnforceMPP(c *C) {
tk := testkit.NewTestKit(c, s.store)
func TestEnforceMPP(t *testing.T) {
store, clean := testkit.CreateMockStore(t)
defer clean()
tk := testkit.NewTestKit(t, store)

// test query
tk.MustExec("use test")
Expand All @@ -87,10 +59,10 @@ func (s *testEnforceMPPSuite) TestEnforceMPP(c *C) {
tk.MustExec("create index idx on t(a)")

// Create virtual tiflash replica info.
dom := domain.GetDomain(tk.Se)
dom := domain.GetDomain(tk.Session())
is := dom.InfoSchema()
db, exists := is.SchemaByName(model.NewCIStr("test"))
c.Assert(exists, IsTrue)
require.True(t, exists)
for _, tblInfo := range db.Tables {
if tblInfo.Name.L == "t" {
tblInfo.TiFlashReplica = &model.TiFlashReplicaInfo{
Expand All @@ -106,7 +78,8 @@ func (s *testEnforceMPPSuite) TestEnforceMPP(c *C) {
Plan []string
Warn []string
}
s.testData.GetTestCases(c, &input, &output)
enforceMPPSuiteData := plannercore.GetEnforceMPPSuiteData()
enforceMPPSuiteData.GetTestCases(t, &input, &output)
filterWarnings := func(originalWarnings []stmtctx.SQLWarn) []stmtctx.SQLWarn {
warnings := make([]stmtctx.SQLWarn, 0, 4)
for _, warning := range originalWarnings {
Expand All @@ -118,27 +91,29 @@ func (s *testEnforceMPPSuite) TestEnforceMPP(c *C) {
return warnings
}
for i, tt := range input {
s.testData.OnRecord(func() {
testdata.OnRecord(func() {
output[i].SQL = tt
})
if strings.HasPrefix(tt, "set") {
tk.MustExec(tt)
continue
}
s.testData.OnRecord(func() {
testdata.OnRecord(func() {
output[i].SQL = tt
output[i].Plan = s.testData.ConvertRowsToStrings(tk.MustQuery(tt).Rows())
output[i].Warn = s.testData.ConvertSQLWarnToStrings(filterWarnings(tk.Se.GetSessionVars().StmtCtx.GetWarnings()))
output[i].Plan = testdata.ConvertRowsToStrings(tk.MustQuery(tt).Rows())
output[i].Warn = testdata.ConvertSQLWarnToStrings(filterWarnings(tk.Session().GetSessionVars().StmtCtx.GetWarnings()))
})
res := tk.MustQuery(tt)
res.Check(testkit.Rows(output[i].Plan...))
c.Assert(s.testData.ConvertSQLWarnToStrings(filterWarnings(tk.Se.GetSessionVars().StmtCtx.GetWarnings())), DeepEquals, output[i].Warn)
require.Equal(t, output[i].Warn, testdata.ConvertSQLWarnToStrings(filterWarnings(tk.Session().GetSessionVars().StmtCtx.GetWarnings())))
}
}

// general cases.
func (s *testEnforceMPPSuite) TestEnforceMPPWarning1(c *C) {
tk := testkit.NewTestKit(c, s.store)
func TestEnforceMPPWarning1(t *testing.T) {
store, clean := testkit.CreateMockStore(t)
defer clean()
tk := testkit.NewTestKit(t, store)

// test query
tk.MustExec("use test")
Expand All @@ -152,9 +127,10 @@ func (s *testEnforceMPPSuite) TestEnforceMPPWarning1(c *C) {
Plan []string
Warn []string
}
s.testData.GetTestCases(c, &input, &output)
enforceMPPSuiteData := plannercore.GetEnforceMPPSuiteData()
enforceMPPSuiteData.GetTestCases(t, &input, &output)
for i, tt := range input {
s.testData.OnRecord(func() {
testdata.OnRecord(func() {
output[i].SQL = tt
})
if strings.HasPrefix(tt, "set") {
Expand All @@ -163,10 +139,10 @@ func (s *testEnforceMPPSuite) TestEnforceMPPWarning1(c *C) {
}
if strings.HasPrefix(tt, "cmd: create-replica") {
// Create virtual tiflash replica info.
dom := domain.GetDomain(tk.Se)
dom := domain.GetDomain(tk.Session())
is := dom.InfoSchema()
db, exists := is.SchemaByName(model.NewCIStr("test"))
c.Assert(exists, IsTrue)
require.True(t, exists)
for _, tblInfo := range db.Tables {
if tblInfo.Name.L == "t" {
tblInfo.TiFlashReplica = &model.TiFlashReplicaInfo{
Expand All @@ -179,10 +155,10 @@ func (s *testEnforceMPPSuite) TestEnforceMPPWarning1(c *C) {
}
if strings.HasPrefix(tt, "cmd: enable-replica") {
// Create virtual tiflash replica info.
dom := domain.GetDomain(tk.Se)
dom := domain.GetDomain(tk.Session())
is := dom.InfoSchema()
db, exists := is.SchemaByName(model.NewCIStr("test"))
c.Assert(exists, IsTrue)
require.True(t, exists)
for _, tblInfo := range db.Tables {
if tblInfo.Name.L == "t" {
tblInfo.TiFlashReplica = &model.TiFlashReplicaInfo{
Expand All @@ -193,31 +169,33 @@ func (s *testEnforceMPPSuite) TestEnforceMPPWarning1(c *C) {
}
continue
}
s.testData.OnRecord(func() {
testdata.OnRecord(func() {
output[i].SQL = tt
output[i].Plan = s.testData.ConvertRowsToStrings(tk.MustQuery(tt).Rows())
output[i].Warn = s.testData.ConvertSQLWarnToStrings(tk.Se.GetSessionVars().StmtCtx.GetWarnings())
output[i].Plan = testdata.ConvertRowsToStrings(tk.MustQuery(tt).Rows())
output[i].Warn = testdata.ConvertSQLWarnToStrings(tk.Session().GetSessionVars().StmtCtx.GetWarnings())
})
res := tk.MustQuery(tt)
res.Check(testkit.Rows(output[i].Plan...))
c.Assert(s.testData.ConvertSQLWarnToStrings(tk.Se.GetSessionVars().StmtCtx.GetWarnings()), DeepEquals, output[i].Warn)
require.Equal(t, output[i].Warn, testdata.ConvertSQLWarnToStrings(tk.Session().GetSessionVars().StmtCtx.GetWarnings()))
}
}

// partition table.
func (s *testEnforceMPPSuite) TestEnforceMPPWarning2(c *C) {
tk := testkit.NewTestKit(c, s.store)
func TestEnforceMPPWarning2(t *testing.T) {
store, clean := testkit.CreateMockStore(t)
defer clean()
tk := testkit.NewTestKit(t, store)

// test query
tk.MustExec("use test")
tk.MustExec("drop table if exists t")
tk.MustExec("CREATE TABLE t (a int, b char(20)) PARTITION BY HASH(a)")

// Create virtual tiflash replica info.
dom := domain.GetDomain(tk.Se)
dom := domain.GetDomain(tk.Session())
is := dom.InfoSchema()
db, exists := is.SchemaByName(model.NewCIStr("test"))
c.Assert(exists, IsTrue)
require.True(t, exists)
for _, tblInfo := range db.Tables {
if tblInfo.Name.L == "t" {
tblInfo.TiFlashReplica = &model.TiFlashReplicaInfo{
Expand All @@ -233,40 +211,43 @@ func (s *testEnforceMPPSuite) TestEnforceMPPWarning2(c *C) {
Plan []string
Warn []string
}
s.testData.GetTestCases(c, &input, &output)
enforceMPPSuiteData := plannercore.GetEnforceMPPSuiteData()
enforceMPPSuiteData.GetTestCases(t, &input, &output)
for i, tt := range input {
s.testData.OnRecord(func() {
testdata.OnRecord(func() {
output[i].SQL = tt
})
if strings.HasPrefix(tt, "set") {
tk.MustExec(tt)
continue
}
s.testData.OnRecord(func() {
testdata.OnRecord(func() {
output[i].SQL = tt
output[i].Plan = s.testData.ConvertRowsToStrings(tk.MustQuery(tt).Rows())
output[i].Warn = s.testData.ConvertSQLWarnToStrings(tk.Se.GetSessionVars().StmtCtx.GetWarnings())
output[i].Plan = testdata.ConvertRowsToStrings(tk.MustQuery(tt).Rows())
output[i].Warn = testdata.ConvertSQLWarnToStrings(tk.Session().GetSessionVars().StmtCtx.GetWarnings())
})
res := tk.MustQuery(tt)
res.Check(testkit.Rows(output[i].Plan...))
c.Assert(s.testData.ConvertSQLWarnToStrings(tk.Se.GetSessionVars().StmtCtx.GetWarnings()), DeepEquals, output[i].Warn)
require.Equal(t, output[i].Warn, testdata.ConvertSQLWarnToStrings(tk.Session().GetSessionVars().StmtCtx.GetWarnings()))
}
}

// new collation.
func (s *testEnforceMPPSuite) TestEnforceMPPWarning3(c *C) {
tk := testkit.NewTestKit(c, s.store)
func TestEnforceMPPWarning3(t *testing.T) {
store, clean := testkit.CreateMockStore(t)
defer clean()
tk := testkit.NewTestKit(t, store)

// test query
tk.MustExec("use test")
tk.MustExec("drop table if exists t")
tk.MustExec("CREATE TABLE t (a int, b char(20))")

// Create virtual tiflash replica info.
dom := domain.GetDomain(tk.Se)
dom := domain.GetDomain(tk.Session())
is := dom.InfoSchema()
db, exists := is.SchemaByName(model.NewCIStr("test"))
c.Assert(exists, IsTrue)
require.True(t, exists)
for _, tblInfo := range db.Tables {
if tblInfo.Name.L == "t" {
tblInfo.TiFlashReplica = &model.TiFlashReplicaInfo{
Expand All @@ -282,9 +263,10 @@ func (s *testEnforceMPPSuite) TestEnforceMPPWarning3(c *C) {
Plan []string
Warn []string
}
s.testData.GetTestCases(c, &input, &output)
enforceMPPSuiteData := plannercore.GetEnforceMPPSuiteData()
enforceMPPSuiteData.GetTestCases(t, &input, &output)
for i, tt := range input {
s.testData.OnRecord(func() {
testdata.OnRecord(func() {
output[i].SQL = tt
})
if strings.HasPrefix(tt, "set") || strings.HasPrefix(tt, "UPDATE") {
Expand All @@ -299,21 +281,23 @@ func (s *testEnforceMPPSuite) TestEnforceMPPWarning3(c *C) {
collate.SetNewCollationEnabledForTest(false)
continue
}
s.testData.OnRecord(func() {
testdata.OnRecord(func() {
output[i].SQL = tt
output[i].Plan = s.testData.ConvertRowsToStrings(tk.MustQuery(tt).Rows())
output[i].Warn = s.testData.ConvertSQLWarnToStrings(tk.Se.GetSessionVars().StmtCtx.GetWarnings())
output[i].Plan = testdata.ConvertRowsToStrings(tk.MustQuery(tt).Rows())
output[i].Warn = testdata.ConvertSQLWarnToStrings(tk.Session().GetSessionVars().StmtCtx.GetWarnings())
})
res := tk.MustQuery(tt)
res.Check(testkit.Rows(output[i].Plan...))
c.Assert(s.testData.ConvertSQLWarnToStrings(tk.Se.GetSessionVars().StmtCtx.GetWarnings()), DeepEquals, output[i].Warn)
require.Equal(t, output[i].Warn, testdata.ConvertSQLWarnToStrings(tk.Session().GetSessionVars().StmtCtx.GetWarnings()))
}
collate.SetNewCollationEnabledForTest(true)
}

// Test enforce mpp warning for joins
func (s *testEnforceMPPSuite) TestEnforceMPPWarning4(c *C) {
tk := testkit.NewTestKit(c, s.store)
func TestEnforceMPPWarning4(t *testing.T) {
store, clean := testkit.CreateMockStore(t)
defer clean()
tk := testkit.NewTestKit(t, store)

// test table
tk.MustExec("use test")
Expand All @@ -323,10 +307,10 @@ func (s *testEnforceMPPSuite) TestEnforceMPPWarning4(c *C) {
tk.MustExec("CREATE TABLE s(a int primary key)")

// Create virtual tiflash replica info.
dom := domain.GetDomain(tk.Se)
dom := domain.GetDomain(tk.Session())
is := dom.InfoSchema()
db, exists := is.SchemaByName(model.NewCIStr("test"))
c.Assert(exists, IsTrue)
require.True(t, exists)
for _, tblInfo := range db.Tables {
if tblInfo.Name.L == "t" || tblInfo.Name.L == "s" {
tblInfo.TiFlashReplica = &model.TiFlashReplicaInfo{
Expand All @@ -342,22 +326,23 @@ func (s *testEnforceMPPSuite) TestEnforceMPPWarning4(c *C) {
Plan []string
Warn []string
}
s.testData.GetTestCases(c, &input, &output)
enforceMPPSuiteData := plannercore.GetEnforceMPPSuiteData()
enforceMPPSuiteData.GetTestCases(t, &input, &output)
for i, tt := range input {
s.testData.OnRecord(func() {
testdata.OnRecord(func() {
output[i].SQL = tt
})
if strings.HasPrefix(tt, "set") || strings.HasPrefix(tt, "UPDATE") {
tk.MustExec(tt)
continue
}
s.testData.OnRecord(func() {
testdata.OnRecord(func() {
output[i].SQL = tt
output[i].Plan = s.testData.ConvertRowsToStrings(tk.MustQuery(tt).Rows())
output[i].Warn = s.testData.ConvertSQLWarnToStrings(tk.Se.GetSessionVars().StmtCtx.GetWarnings())
output[i].Plan = testdata.ConvertRowsToStrings(tk.MustQuery(tt).Rows())
output[i].Warn = testdata.ConvertSQLWarnToStrings(tk.Session().GetSessionVars().StmtCtx.GetWarnings())
})
res := tk.MustQuery(tt)
res.Check(testkit.Rows(output[i].Plan...))
c.Assert(s.testData.ConvertSQLWarnToStrings(tk.Se.GetSessionVars().StmtCtx.GetWarnings()), DeepEquals, output[i].Warn)
require.Equal(t, output[i].Warn, testdata.ConvertSQLWarnToStrings(tk.Session().GetSessionVars().StmtCtx.GetWarnings()))
}
}
7 changes: 6 additions & 1 deletion planner/core/main_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ import (
"go.uber.org/goleak"
)

var testDataMap = make(testdata.BookKeeper, 5)
var testDataMap = make(testdata.BookKeeper)
var indexMergeSuiteData testdata.TestData

func TestMain(m *testing.M) {
Expand All @@ -38,6 +38,7 @@ func TestMain(m *testing.M) {
testDataMap.LoadTestSuiteData("testdata", "stats_suite")
testDataMap.LoadTestSuiteData("testdata", "ordered_result_mode_suite")
testDataMap.LoadTestSuiteData("testdata", "point_get_plan")
testDataMap.LoadTestSuiteData("testdata", "enforce_mpp_suite")

indexMergeSuiteData = testDataMap["index_merge_suite"]

Expand Down Expand Up @@ -73,3 +74,7 @@ func GetOrderedResultModeSuiteData() testdata.TestData {
func GetPointGetPlanData() testdata.TestData {
return testDataMap["point_get_plan"]
}

func GetEnforceMPPSuiteData() testdata.TestData {
return testDataMap["enforce_mpp_suite"]
}
Loading

0 comments on commit 36a1f84

Please sign in to comment.