Skip to content

Commit

Permalink
planner: add more test cases for MPP hints (#39933)
Browse files Browse the repository at this point in the history
  • Loading branch information
Reminiscent authored Dec 21, 2022
1 parent 2150c6b commit 785f515
Show file tree
Hide file tree
Showing 5 changed files with 984 additions and 9 deletions.
32 changes: 31 additions & 1 deletion planner/core/exhaust_physical_plans.go
Original file line number Diff line number Diff line change
Expand Up @@ -1928,7 +1928,7 @@ func (p *LogicalJoin) exhaustPhysicalPlans(prop *property.PhysicalProperty) ([]P
}
})

if (p.preferJoinType&preferBCJoin) == 0 && p.preferJoinType > 0 {
if (p.preferJoinType&preferBCJoin) == 0 && (p.preferJoinType&preferShuffleJoin) == 0 && p.preferJoinType > 0 {
p.SCtx().GetSessionVars().RaiseWarningWhenMPPEnforced("MPP mode may be blocked because you have used hint to specify a join algorithm which is not supported by mpp now.")
if prop.IsFlashProp() {
return nil, false, nil
Expand Down Expand Up @@ -1957,6 +1957,21 @@ func (p *LogicalJoin) exhaustPhysicalPlans(prop *property.PhysicalProperty) ([]P
mppJoins := p.tryToGetMppHashJoin(prop, false)
joins = append(joins, mppJoins...)
}
} else {
hasMppHints := false
var errMsg string
if (p.preferJoinType & preferShuffleJoin) > 0 {
errMsg = "The join can not push down to the MPP side, the shuffle_join() hint is invalid"
hasMppHints = true
}
if (p.preferJoinType & preferBCJoin) > 0 {
errMsg = "The join can not push down to the MPP side, the broadcast_join() hint is invalid"
hasMppHints = true
}
if hasMppHints {
warning := ErrInternal.GenWithStack(errMsg)
p.ctx.GetSessionVars().StmtCtx.AppendWarning(warning)
}
}
if prop.IsFlashProp() {
return joins, true, nil
Expand Down Expand Up @@ -2858,6 +2873,21 @@ func (la *LogicalAggregation) getHashAggs(prop *property.PhysicalProperty) []Phy
}
if canPushDownToMPP {
taskTypes = append(taskTypes, property.MppTaskType)
} else {
hasMppHints := false
var errMsg string
if la.aggHints.preferAggType&preferMPP1PhaseAgg > 0 {
errMsg = "The agg can not push down to the MPP side, the MPP_1PHASE_AGG() hint is invalid"
hasMppHints = true
}
if la.aggHints.preferAggType&preferMPP2PhaseAgg > 0 {
errMsg = "The agg can not push down to the MPP side, the MPP_2PHASE_AGG() hint is invalid"
hasMppHints = true
}
if hasMppHints {
warning := ErrInternal.GenWithStack(errMsg)
la.ctx.GetSessionVars().StmtCtx.AppendWarning(warning)
}
}
if prop.IsFlashProp() {
taskTypes = []property.TaskType{prop.TaskTp}
Expand Down
95 changes: 94 additions & 1 deletion planner/core/physical_plan_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -813,9 +813,11 @@ func TestMPPHints(t *testing.T) {
tk := testkit.NewTestKit(t, store)
tk.MustExec("use test")
tk.MustExec("set tidb_cost_model_version=2")
tk.MustExec("create table t (a int, b int, c int)")
tk.MustExec("create table t (a int, b int, c int, index idx_a(a), index idx_b(b))")
tk.MustExec("alter table t set tiflash replica 1")
tk.MustExec("set @@session.tidb_allow_mpp=ON")
tk.MustExec("create definer='root'@'localhost' view v as select a, sum(b) from t group by a, c;")
tk.MustExec("create definer='root'@'localhost' view v1 as select t1.a from t t1, t t2 where t1.a=t2.a;")
tb := external.GetTableByName(t, tk, "test", "t")
err := domain.GetDomain(tk.Session()).DDL().UpdateTableReplicaInfo(tk.Session(), tb.Meta().ID, true)
require.NoError(t, err)
Expand All @@ -824,6 +826,7 @@ func TestMPPHints(t *testing.T) {
var output []struct {
SQL string
Plan []string
Warn []string
}

planSuiteData := core.GetPlanSuiteData()
Expand All @@ -833,11 +836,101 @@ func TestMPPHints(t *testing.T) {
testdata.OnRecord(func() {
output[i].SQL = ts
output[i].Plan = testdata.ConvertRowsToStrings(tk.MustQuery("explain format = 'brief' " + ts).Rows())
output[i].Warn = testdata.ConvertSQLWarnToStrings(tk.Session().GetSessionVars().StmtCtx.GetWarnings())
})
tk.MustQuery("explain format = 'brief' " + ts).Check(testkit.Rows(output[i].Plan...))
require.Equal(t, output[i].Warn, testdata.ConvertSQLWarnToStrings(tk.Session().GetSessionVars().StmtCtx.GetWarnings()))
}
}

func TestMPPHintsScope(t *testing.T) {
store := testkit.CreateMockStore(t, withMockTiFlash(2))
tk := testkit.NewTestKit(t, store)
tk.MustExec("use test")
tk.MustExec("set tidb_cost_model_version=2")
tk.MustExec("create table t (a int, b int, c int, index idx_a(a), index idx_b(b))")
tk.MustExec("select /*+ MPP_1PHASE_AGG() */ a, sum(b) from t group by a, c")
tk.MustQuery("show warnings").Check(testkit.Rows("Warning 1815 The agg can not push down to the MPP side, the MPP_1PHASE_AGG() hint is invalid"))
tk.MustExec("select /*+ MPP_2PHASE_AGG() */ a, sum(b) from t group by a, c")
tk.MustQuery("show warnings").Check(testkit.Rows("Warning 1815 The agg can not push down to the MPP side, the MPP_2PHASE_AGG() hint is invalid"))
tk.MustExec("select /*+ shuffle_join(t1, t2) */ * from t t1, t t2 where t1.a=t2.a")
tk.MustQuery("show warnings").Check(testkit.Rows("Warning 1815 The join can not push down to the MPP side, the shuffle_join() hint is invalid"))
tk.MustExec("select /*+ broadcast_join(t1, t2) */ * from t t1, t t2 where t1.a=t2.a")
tk.MustQuery("show warnings").Check(testkit.Rows("Warning 1815 The join can not push down to the MPP side, the broadcast_join() hint is invalid"))
tk.MustExec("alter table t set tiflash replica 1")
tb := external.GetTableByName(t, tk, "test", "t")
err := domain.GetDomain(tk.Session()).DDL().UpdateTableReplicaInfo(tk.Session(), tb.Meta().ID, true)
require.NoError(t, err)
tk.MustExec("set @@session.tidb_allow_mpp=true")
tk.MustExec("explain select /*+ MPP_1PHASE_AGG() */ a, sum(b) from t group by a, c")
tk.MustQuery("show warnings").Check(testkit.Rows())
tk.MustExec("explain select /*+ MPP_2PHASE_AGG() */ a, sum(b) from t group by a, c")
tk.MustQuery("show warnings").Check(testkit.Rows())
tk.MustExec("explain select /*+ shuffle_join(t1, t2) */ * from t t1, t t2 where t1.a=t2.a")
tk.MustQuery("show warnings").Check(testkit.Rows())
tk.MustExec("explain select /*+ broadcast_join(t1, t2) */ * from t t1, t t2 where t1.a=t2.a")
tk.MustQuery("show warnings").Check(testkit.Rows())

tk.MustExec("set @@session.tidb_allow_mpp=false")
tk.MustExec("explain select /*+ MPP_1PHASE_AGG() */ a, sum(b) from t group by a, c")
tk.MustQuery("show warnings").Check(testkit.Rows("Warning 1815 The agg can not push down to the MPP side, the MPP_1PHASE_AGG() hint is invalid"))
tk.MustExec("explain select /*+ MPP_2PHASE_AGG() */ a, sum(b) from t group by a, c")
tk.MustQuery("show warnings").Check(testkit.Rows("Warning 1815 The agg can not push down to the MPP side, the MPP_2PHASE_AGG() hint is invalid"))
tk.MustExec("explain select /*+ shuffle_join(t1, t2) */ * from t t1, t t2 where t1.a=t2.a")
tk.MustQuery("show warnings").Check(testkit.Rows("Warning 1815 The join can not push down to the MPP side, the shuffle_join() hint is invalid"))
tk.MustExec("explain select /*+ broadcast_join(t1, t2) */ * from t t1, t t2 where t1.a=t2.a")
tk.MustQuery("show warnings").Check(testkit.Rows("Warning 1815 The join can not push down to the MPP side, the broadcast_join() hint is invalid"))
}

func TestMPPHintsWithBinding(t *testing.T) {
store := testkit.CreateMockStore(t, withMockTiFlash(2))
tk := testkit.NewTestKit(t, store)
tk.MustExec("use test")
tk.MustExec("set tidb_cost_model_version=2")
tk.MustExec("create table t (a int, b int, c int)")
tk.MustExec("alter table t set tiflash replica 1")
tk.MustExec("set @@session.tidb_allow_mpp=ON")
tb := external.GetTableByName(t, tk, "test", "t")
err := domain.GetDomain(tk.Session()).DDL().UpdateTableReplicaInfo(tk.Session(), tb.Meta().ID, true)
require.NoError(t, err)

tk.MustExec("explain select a, sum(b) from t group by a, c")
tk.MustQuery("select @@last_plan_from_binding").Check(testkit.Rows("0"))
tk.MustExec("create global binding for select a, sum(b) from t group by a, c using select /*+ read_from_storage(tiflash[t]), MPP_1PHASE_AGG() */ a, sum(b) from t group by a, c;")
tk.MustExec("explain select a, sum(b) from t group by a, c")
tk.MustQuery("select @@last_plan_from_binding").Check(testkit.Rows("1"))
res := tk.MustQuery("show global bindings").Rows()
require.Equal(t, res[0][0], "select `a` , sum ( `b` ) from `test` . `t` group by `a` , `c`")
require.Equal(t, res[0][1], "SELECT /*+ read_from_storage(tiflash[`t`]) MPP_1PHASE_AGG()*/ `a`,sum(`b`) FROM `test`.`t` GROUP BY `a`,`c`")
tk.MustExec("create global binding for select a, sum(b) from t group by a, c using select /*+ read_from_storage(tiflash[t]), MPP_2PHASE_AGG() */ a, sum(b) from t group by a, c;")
tk.MustExec("explain select a, sum(b) from t group by a, c")
tk.MustQuery("select @@last_plan_from_binding").Check(testkit.Rows("1"))
res = tk.MustQuery("show global bindings").Rows()
require.Equal(t, res[0][0], "select `a` , sum ( `b` ) from `test` . `t` group by `a` , `c`")
require.Equal(t, res[0][1], "SELECT /*+ read_from_storage(tiflash[`t`]) MPP_2PHASE_AGG()*/ `a`,sum(`b`) FROM `test`.`t` GROUP BY `a`,`c`")
tk.MustExec("drop global binding for select a, sum(b) from t group by a, c;")
res = tk.MustQuery("show global bindings").Rows()
require.Equal(t, len(res), 0)

tk.MustExec("explain select * from t t1, t t2 where t1.a=t2.a")
tk.MustQuery("select @@last_plan_from_binding").Check(testkit.Rows("0"))
tk.MustExec("create global binding for select * from t t1, t t2 where t1.a=t2.a using select /*+ read_from_storage(tiflash[t1, t2]), shuffle_join(t1, t2) */ * from t t1, t t2 where t1.a=t2.a")
tk.MustExec("explain select * from t t1, t t2 where t1.a=t2.a")
tk.MustQuery("select @@last_plan_from_binding").Check(testkit.Rows("1"))
res = tk.MustQuery("show global bindings").Rows()
require.Equal(t, res[0][0], "select * from ( `test` . `t` as `t1` ) join `test` . `t` as `t2` where `t1` . `a` = `t2` . `a`")
require.Equal(t, res[0][1], "SELECT /*+ read_from_storage(tiflash[`t1`, `t2`]) shuffle_join(`t1`, `t2`)*/ * FROM (`test`.`t` AS `t1`) JOIN `test`.`t` AS `t2` WHERE `t1`.`a` = `t2`.`a`")
tk.MustExec("create global binding for select * from t t1, t t2 where t1.a=t2.a using select /*+ read_from_storage(tiflash[t1, t2]), broadcast_join(t1, t2) */ * from t t1, t t2 where t1.a=t2.a;")
tk.MustExec("explain select * from t t1, t t2 where t1.a=t2.a")
tk.MustQuery("select @@last_plan_from_binding").Check(testkit.Rows("1"))
res = tk.MustQuery("show global bindings").Rows()
require.Equal(t, res[0][0], "select * from ( `test` . `t` as `t1` ) join `test` . `t` as `t2` where `t1` . `a` = `t2` . `a`")
require.Equal(t, res[0][1], "SELECT /*+ read_from_storage(tiflash[`t1`, `t2`]) broadcast_join(`t1`, `t2`)*/ * FROM (`test`.`t` AS `t1`) JOIN `test`.`t` AS `t2` WHERE `t1`.`a` = `t2`.`a`")
tk.MustExec("drop global binding for select * from t t1, t t2 where t1.a=t2.a;")
res = tk.MustQuery("show global bindings").Rows()
require.Equal(t, len(res), 0)
}

func TestHintScope(t *testing.T) {
store := testkit.CreateMockStore(t)
tk := testkit.NewTestKit(t, store)
Expand Down
6 changes: 3 additions & 3 deletions planner/core/testdata/integration_suite_out.json
Original file line number Diff line number Diff line change
Expand Up @@ -6229,9 +6229,9 @@
{
"SQL": "explain format = 'brief' select /*+ qb_name(qb, v12) read_from_storage(tiflash[t1@qb, t@qb]), shuffle_join(t1@qb, t@qb) */ * from v12;",
"Plan": [
"Projection 12500.00 root test.t.a, test.t.b",
"└─TableReader 12500.00 root data:ExchangeSender",
" └─ExchangeSender 12500.00 mpp[tiflash] ExchangeType: PassThrough",
"TableReader 12500.00 root data:ExchangeSender",
"└─ExchangeSender 12500.00 mpp[tiflash] ExchangeType: PassThrough",
" └─Projection 12500.00 mpp[tiflash] test.t.a, test.t.b",
" └─HashJoin 12500.00 mpp[tiflash] inner join, equal:[eq(test.t.a, test.t.a)]",
" ├─ExchangeReceiver(Build) 10000.00 mpp[tiflash] ",
" │ └─ExchangeSender 10000.00 mpp[tiflash] ExchangeType: HashPartition, Hash Cols: [name: test.t.a, collate: binary]",
Expand Down
62 changes: 61 additions & 1 deletion planner/core/testdata/plan_suite_in.json
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,70 @@
{
"name": "TestMPPHints",
"cases": [
"select /*+ MPP_1PHASE_AGG() */ a, sum(b) from t group by a, c",
"select /*+ MPP_2PHASE_AGG() */ a, sum(b) from t group by a, c",
"select /*+ shuffle_join(t1, t2) */ * from t t1, t t2 where t1.a=t2.a",
"select /*+ broadcast_join(t1, t2) */ * from t t1, t t2 where t1.a=t2.a",

// READ_FROM_STORAGE hint
"select /*+ read_from_storage(tiflash[t]), MPP_1PHASE_AGG() */ a, sum(b) from t group by a, c",
"select /*+ read_from_storage(tiflash[t]), MPP_2PHASE_AGG() */ a, sum(b) from t group by a, c",
"select /*+ read_from_storage(tiflash[t1, t2]), shuffle_join(t1, t2) */ * from t t1, t t2 where t1.a=t2.a",
"select /*+ read_from_storage(tiflash[t1, t2]), broadcast_join(t1, t2) */ * from t t1, t t2 where t1.a=t2.a"
"select /*+ read_from_storage(tiflash[t1, t2]), broadcast_join(t1, t2) */ * from t t1, t t2 where t1.a=t2.a",

// Join hint
"select /*+ read_from_storage(tiflash[t1, t2]), shuffle_join(t1, t2), hash_join(t1, t2) */ * from t t1, t t2 where t1.a=t2.a",
"select /*+ read_from_storage(tiflash[t1, t2]), broadcast_join(t1, t2), hash_join(t1, t2) */ * from t t1, t t2 where t1.a=t2.a",

"select /*+ read_from_storage(tiflash[t1, t2]), shuffle_join(t1, t2), hash_join_build(t1) */ * from t t1, t t2 where t1.a=t2.a",
"select /*+ read_from_storage(tiflash[t1, t2]), broadcast_join(t1, t2), hash_join_build(t2) */ * from t t1, t t2 where t1.a=t2.a",

"select /*+ read_from_storage(tiflash[t1, t2]), shuffle_join(t1, t2), hash_join_probe(t1) */ * from t t1, t t2 where t1.a=t2.a",
"select /*+ read_from_storage(tiflash[t1, t2]), broadcast_join(t1, t2), hash_join_probe(t2) */ * from t t1, t t2 where t1.a=t2.a",

"select /*+ read_from_storage(tiflash[t1, t2]), shuffle_join(t1, t2), merge_join(t1, t2) */ * from t t1, t t2 where t1.a=t2.a",
"select /*+ read_from_storage(tiflash[t1, t2]), broadcast_join(t1, t2), merge_join(t1, t2) */ * from t t1, t t2 where t1.a=t2.a",

"select /*+ read_from_storage(tiflash[t1, t2]), shuffle_join(t1, t2), INL_JOIN(t1, t2) */ * from t t1, t t2 where t1.a=t2.a",
"select /*+ read_from_storage(tiflash[t1, t2]), broadcast_join(t1, t2), INL_JOIN(t1, t2) */ * from t t1, t t2 where t1.a=t2.a",

// AGG hint
"select /*+ read_from_storage(tiflash[t]), MPP_1PHASE_AGG(), hash_agg() */ a, sum(b) from t group by a, c",
"select /*+ read_from_storage(tiflash[t]), MPP_2PHASE_AGG(), stream_agg() */ a, sum(b) from t group by a, c",

// Index hint
"select /*+ read_from_storage(tiflash[t]), MPP_1PHASE_AGG(), use_index(t, idx_a) */ a, sum(b) from t where a > 1 group by a, c",
"select /*+ read_from_storage(tiflash[t]), MPP_1PHASE_AGG(), ignore_index(t, idx_a) */ a, sum(b) from t where a > 1 group by a, c",
"select /*+ read_from_storage(tiflash[t]), MPP_2PHASE_AGG(), force_index(t, idx_b) */ a, sum(b) from t where b < 2 group by a, c",
"select /*+ read_from_storage(tiflash[t]), MPP_2PHASE_AGG(), index_merge(t, idx_b, idx_a) */ a, sum(b) from t where b < 2 or a > 2 group by a, c",

// Join Order hint
"select /*+ read_from_storage(tiflash[t1, t2, t3]), shuffle_join(t1, t2, t3), straight_join() */ * from t t1, t t2, t t3 where t1.a=t2.a and t2.b=t3.b",
"select /*+ read_from_storage(tiflash[t1, t2, t3]), shuffle_join(t1, t2, t3), leading(t3, t1) */ * from t t1, t t2, t t3 where t1.a=t2.a and t2.b=t3.b",
"select /*+ read_from_storage(tiflash[t1, t2, t3]), broadcast_join(t1, t2, t3), straight_join() */ * from t t2, t t1, t t3 where t1.a=t2.a and t2.b=t3.b",
"select /*+ read_from_storage(tiflash[t1, t2, t3]), broadcast_join(t1, t2, t3), leading(t2, t3) */ * from t t1, t t2, t t3 where t1.a=t2.a and t2.b=t3.b",

// View Hint
"select /*+ qb_name(qb, v), MPP_1PHASE_AGG(@qb) */ * from v",
"select /*+ qb_name(qb, v), MPP_2PHASE_AGG(@qb) */ * from v",
"select /*+ qb_name(qb, v1), shuffle_join(t1@qb, t2@qb) */ * from v1",
"select /*+ qb_name(qb, v1), broadcast_join(t1@qb, t2@qb) */ * from v1",

// Subquery hint
"SELECT /*+ shuffle_join(t) */ * FROM t WHERE EXISTS (SELECT /*+ SEMI_JOIN_REWRITE */ 1 FROM t t1 WHERE t1.b = t.b);",
"SELECT /*+ broadcast_join(t) */ * FROM t WHERE EXISTS (SELECT /*+ SEMI_JOIN_REWRITE */ 1 FROM t t1 WHERE t1.b = t.b);",
"select * from t t1 where t1.a < (select /*+ MPP_1PHASE_AGG() */ sum(t2.a) from t t2 where t2.b = t1.b);",
"select * from t t1 where t1.a < (select /*+ MPP_2PHASE_AGG() */ sum(t2.a) from t t2 where t2.b = t1.b);",

// CTE
"WITH CTE AS (SELECT /*+ MPP_1PHASE_AGG() */ count(*) as a, b FROM t WHERE t.a < 60 group by b) SELECT * FROM CTE WHERE CTE.a <18 union select * from cte where cte.b > 1;",
"WITH CTE AS (SELECT /*+ MPP_2PHASE_AGG() */ count(*) as a, b FROM t WHERE t.a < 60 group by b) SELECT * FROM CTE WHERE CTE.a <18 union select * from cte where cte.b > 1;",
"WITH CTE AS (SELECT /*+ shuffle_join(t1, t) */ t.a, t.b FROM t join t t1 where t.a = t1.a) SELECT * FROM CTE WHERE CTE.a <18 union select * from cte where cte.b > 1;",
"WITH CTE AS (SELECT /*+ broadcast_join(t1, t) */ t.a, t.b FROM t join t t1 where t.a = t1.a) SELECT * FROM CTE WHERE CTE.a <18 union select * from cte where cte.b > 1;",
"WITH CTE AS (SELECT /*+ MERGE(), MPP_1PHASE_AGG() */ count(*) as a, b FROM t WHERE t.a < 60 group by b) SELECT * FROM CTE WHERE CTE.a <18 union select * from cte where cte.b > 1;",
"WITH CTE AS (SELECT /*+ MERGE(), MPP_2PHASE_AGG() */ count(*) as a, b FROM t WHERE t.a < 60 group by b) SELECT * FROM CTE WHERE CTE.a <18 union select * from cte where cte.b > 1;",
"WITH CTE AS (SELECT /*+ MERGE(), shuffle_join(t1, t) */ t.a, t.b FROM t join t t1 where t.a = t1.a) SELECT * FROM CTE WHERE CTE.a <18 union select * from cte where cte.b > 1;",
"WITH CTE AS (SELECT /*+ MERGE(), broadcast_join(t1, t) */ t.a, t.b FROM t join t t1 where t.a = t1.a) SELECT * FROM CTE WHERE CTE.a <18 union select * from cte where cte.b > 1;"
]
},
{
Expand Down
Loading

0 comments on commit 785f515

Please sign in to comment.