Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

planner: refactor Converting Partition Keys for shuffle hash join (#24456) #24490

Merged
merged 4 commits into from
May 8, 2021
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion planner/core/integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2943,16 +2943,19 @@ func (s *testIntegrationSerialSuite) TestMppJoinDecimal(c *C) {
tk := testkit.NewTestKit(c, s.store)
tk.MustExec("use test")
tk.MustExec("drop table if exists t")
tk.MustExec("drop table if exists tt")
tk.MustExec("create table t (c1 decimal(8, 5), c2 decimal(9, 5), c3 decimal(9, 4) NOT NULL, c4 decimal(8, 4) NOT NULL, c5 decimal(40, 20))")
tk.MustExec("create table tt (pk int(11) NOT NULL AUTO_INCREMENT primary key,col_varchar_64 varchar(64),col_char_64_not_null char(64) NOT null, col_decimal_30_10_key decimal(30,10), col_tinyint tinyint, col_varchar_key varchar(1), key col_decimal_30_10_key (col_decimal_30_10_key), key col_varchar_key(col_varchar_key));")
tk.MustExec("analyze table t")
tk.MustExec("analyze table tt")

// Create virtual tiflash replica info.
dom := domain.GetDomain(tk.Se)
is := dom.InfoSchema()
db, exists := is.SchemaByName(model.NewCIStr("test"))
c.Assert(exists, IsTrue)
for _, tblInfo := range db.Tables {
if tblInfo.Name.L == "t" {
if tblInfo.Name.L == "t" || tblInfo.Name.L == "tt" {
tblInfo.TiFlashReplica = &model.TiFlashReplicaInfo{
Count: 1,
Available: true,
Expand Down
52 changes: 25 additions & 27 deletions planner/core/task.go
Original file line number Diff line number Diff line change
Expand Up @@ -641,9 +641,9 @@ func appendExpr(p *PhysicalProjection, expr expression.Expression) *expression.C
return col
}

// TiFlash join require that join key has exactly the same type, while TiDB only guarantee the join key is the same catalog,
// so if the join key type is not exactly the same, we need add a projection below the join or exchanger if exists.
func (p *PhysicalHashJoin) convertJoinKeyForTiFlashIfNeed(lTask, rTask *mppTask) (*mppTask, *mppTask) {
// TiFlash join requires that the partition keys have exactly the same type, while TiDB only guarantees that the partition keys are in the same catalog,
// so if the partition key types are not exactly the same, we need to add a projection below the join, or below the exchanger if one exists.
func (p *PhysicalHashJoin) convertPartitionKeysIfNeed(lTask, rTask *mppTask) (*mppTask, *mppTask) {
lp := lTask.p
if _, ok := lp.(*PhysicalExchangeReceiver); ok {
lp = lp.Children()[0].Children()[0]
Expand All @@ -652,15 +652,15 @@ func (p *PhysicalHashJoin) convertJoinKeyForTiFlashIfNeed(lTask, rTask *mppTask)
if _, ok := rp.(*PhysicalExchangeReceiver); ok {
rp = rp.Children()[0].Children()[0]
}
// to mark if any equal cond needs to convert
lMask := make([]bool, len(p.EqualConditions))
rMask := make([]bool, len(p.EqualConditions))
cTypes := make([]*types.FieldType, len(p.EqualConditions))
// marks, for each partition key, whether it needs to be converted
lMask := make([]bool, len(lTask.hashCols))
rMask := make([]bool, len(rTask.hashCols))
cTypes := make([]*types.FieldType, len(lTask.hashCols))
lChanged := false
rChanged := false
for i, eqFunc := range p.EqualConditions {
lKey := eqFunc.GetArgs()[0].(*expression.Column)
rKey := eqFunc.GetArgs()[1].(*expression.Column)
for i := range lTask.hashCols {
lKey := lTask.hashCols[i]
rKey := rTask.hashCols[i]
cType, lConvert, rConvert := negotiateCommonType(lKey.RetType, rKey.RetType)
if lConvert {
lMask[i] = true
Expand All @@ -685,14 +685,12 @@ func (p *PhysicalHashJoin) convertJoinKeyForTiFlashIfNeed(lTask, rTask *mppTask)
rProj = getProj(p.ctx, rp)
rp = rProj
}
newEqCondition := make([]*expression.ScalarFunction, 0, len(p.EqualConditions))
newEqCondition = append(newEqCondition, p.EqualConditions...)
p.EqualConditions = newEqCondition
lKeys := make([]*expression.Column, 0, len(p.EqualConditions))
rKeys := make([]*expression.Column, 0, len(p.EqualConditions))
for i, eqFunc := range p.EqualConditions {
lKey := eqFunc.GetArgs()[0].(*expression.Column)
rKey := eqFunc.GetArgs()[1].(*expression.Column)

lPartKeys := make([]*expression.Column, 0, len(rTask.hashCols))
rPartKeys := make([]*expression.Column, 0, len(lTask.hashCols))
for i := range lTask.hashCols {
lKey := lTask.hashCols[i]
rKey := rTask.hashCols[i]
if lMask[i] {
cType := cTypes[i].Clone()
cType.Flag = lKey.RetType.Flag
Expand All @@ -705,12 +703,8 @@ func (p *PhysicalHashJoin) convertJoinKeyForTiFlashIfNeed(lTask, rTask *mppTask)
rCast := expression.BuildCastFunction(p.ctx, rKey, cType)
rKey = appendExpr(rProj, rCast)
}
if lMask[i] || rMask[i] {
eqCond := expression.NewFunctionInternal(p.ctx, ast.EQ, types.NewFieldType(mysql.TypeTiny), lKey, rKey)
p.EqualConditions[i] = eqCond.(*expression.ScalarFunction)
}
lKeys = append(lKeys, lKey)
rKeys = append(rKeys, rKey)
lPartKeys = append(lPartKeys, lKey)
rPartKeys = append(rPartKeys, rKey)
}
// if left or right child changes, we need to add enforcer.
if lChanged {
Expand All @@ -719,7 +713,7 @@ func (p *PhysicalHashJoin) convertJoinKeyForTiFlashIfNeed(lTask, rTask *mppTask)
nlTask = nlTask.enforceExchangerImpl(&property.PhysicalProperty{
TaskTp: property.MppTaskType,
PartitionTp: property.HashType,
PartitionCols: lKeys,
PartitionCols: lPartKeys,
})
nlTask.cst = lTask.cst
lTask = nlTask
Expand All @@ -730,7 +724,7 @@ func (p *PhysicalHashJoin) convertJoinKeyForTiFlashIfNeed(lTask, rTask *mppTask)
nrTask = nrTask.enforceExchangerImpl(&property.PhysicalProperty{
TaskTp: property.MppTaskType,
PartitionTp: property.HashType,
PartitionCols: rKeys,
PartitionCols: rPartKeys,
})
nrTask.cst = rTask.cst
rTask = nrTask
Expand All @@ -745,7 +739,11 @@ func (p *PhysicalHashJoin) attach2TaskForMpp(tasks ...task) task {
return invalidTask
}
if p.mppShuffleJoin {
lTask, rTask = p.convertJoinKeyForTiFlashIfNeed(lTask, rTask)
// protection check in case of some bugs
if len(lTask.hashCols) != len(rTask.hashCols) || len(lTask.hashCols) == 0 {
return invalidTask
}
lTask, rTask = p.convertPartitionKeysIfNeed(lTask, rTask)
}
p.SetChildren(lTask.plan(), rTask.plan())
p.schema = BuildPhysicalJoinSchema(p.JoinType, p)
Expand Down
3 changes: 2 additions & 1 deletion planner/core/testdata/integration_serial_suite_in.json
Original file line number Diff line number Diff line change
Expand Up @@ -213,7 +213,8 @@
"desc format = 'brief' select * from t t1 join t t2 on t1.c1 = t2.c2 and t1.c2 = t2.c3 and t1.c3 = t2.c1 and t1.c4 = t2.c3 and t1.c1 = t2.c5",
"desc format = 'brief' select * from t t1 join t t2 on t1.c1 + t1.c2 = t2.c2 / t2.c3",
"desc format = 'brief' select * from t t1 where exists (select * from t t2 where t1.c1 = t2.c2 and t1.c2 = t2.c3 and t1.c3 = t2.c1 and t1.c4 = t2.c3 and t1.c1 = t2.c5)",
"desc format = 'brief' select * from t t1 left join t t2 on t1.c1 = t2.c2 join t t3 on t2.c5 = t3.c3 right join t t4 on t3.c3 = t4.c4 "
"desc format = 'brief' select * from t t1 left join t t2 on t1.c1 = t2.c2 join t t3 on t2.c5 = t3.c3 right join t t4 on t3.c3 = t4.c4 ",
"desc format = 'brief' SELECT STRAIGHT_JOIN t1 . col_varchar_64 , t1 . col_char_64_not_null FROM tt AS t1 INNER JOIN( tt AS t2 JOIN tt AS t3 ON(t3 . col_decimal_30_10_key = t2 . col_tinyint)) ON(t3 . col_varchar_64 = t2 . col_varchar_key) WHERE t3 . col_varchar_64 = t1 . col_char_64_not_null GROUP BY 1 , 2"
]
},
{
Expand Down
Loading