Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

*: fix a bug that the pessimistic lock doesn't work on a partition (#14921) #15114

Merged
merged 4 commits into from
Mar 9, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions executor/builder.go
Original file line number Diff line number Diff line change
Expand Up @@ -557,8 +557,9 @@ func (b *executorBuilder) buildSelectLock(v *plannercore.PhysicalLock) Executor
return src
}
e := &SelectLockExec{
baseExecutor: newBaseExecutor(b.ctx, v.Schema(), v.ExplainID(), src),
Lock: v.Lock,
baseExecutor: newBaseExecutor(b.ctx, v.Schema(), v.ExplainID(), src),
Lock: v.Lock,
partitionedTable: v.PartitionedTable,
}
return e
}
Expand Down
38 changes: 33 additions & 5 deletions executor/executor.go
Original file line number Diff line number Diff line change
Expand Up @@ -742,6 +742,11 @@ type SelectLockExec struct {

Lock ast.SelectLockType
keys []kv.Key

partitionedTable []table.PartitionedTable

// tblID2Table is cached to reduce cost.
tblID2Table map[int64]table.PartitionedTable
}

// Open implements the Executor Open interface.
Expand All @@ -755,6 +760,18 @@ func (e *SelectLockExec) Open(ctx context.Context) error {
// This operation is only for schema validator check.
txnCtx.UpdateDeltaForTable(id, 0, 0, map[int64]int64{})
}

if len(e.Schema().TblID2Handle) > 0 && len(e.partitionedTable) > 0 {
e.tblID2Table = make(map[int64]table.PartitionedTable, len(e.partitionedTable))
for id := range e.Schema().TblID2Handle {
for _, p := range e.partitionedTable {
if id == p.Meta().ID {
e.tblID2Table[id] = p
}
}
}
}

return nil
}

Expand All @@ -774,12 +791,23 @@ func (e *SelectLockExec) Next(ctx context.Context, req *chunk.Chunk) error {
if len(e.Schema().TblID2Handle) == 0 || e.Lock != ast.SelectLockForUpdate {
return nil
}
if req.NumRows() != 0 {

if req.NumRows() > 0 {
iter := chunk.NewIterator4Chunk(req)
for id, cols := range e.Schema().TblID2Handle {
for _, col := range cols {
for row := iter.Begin(); row != iter.End(); row = iter.Next() {
e.keys = append(e.keys, tablecodec.EncodeRowKeyWithHandle(id, row.GetInt64(col.Index)))
for row := iter.Begin(); row != iter.End(); row = iter.Next() {
for id, cols := range e.Schema().TblID2Handle {
physicalID := id
if pt, ok := e.tblID2Table[id]; ok {
// On a partitioned table, we have to use physical ID to encode the lock key!
p, err := pt.GetPartitionByRow(e.ctx, row.GetDatumRow(e.base().retFieldTypes))
if err != nil {
return err
}
physicalID = p.GetPhysicalID()
}

for _, col := range cols {
e.keys = append(e.keys, tablecodec.EncodeRowKeyWithHandle(physicalID, row.GetInt64(col.Index)))
}
}
}
Expand Down
12 changes: 11 additions & 1 deletion executor/write.go
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,17 @@ func updateRecord(ctx sessionctx.Context, h int64, oldData, newData []types.Datu
if ctx.GetSessionVars().ClientCapability&mysql.ClientFoundRows > 0 {
sc.AddAffectedRows(1)
}
unchangedRowKey := tablecodec.EncodeRowKeyWithHandle(t.Meta().ID, h)

physicalID := t.Meta().ID
if pt, ok := t.(table.PartitionedTable); ok {
p, err := pt.GetPartitionByRow(ctx, oldData)
if err != nil {
return false, false, 0, err
}
physicalID = p.GetPhysicalID()
}

unchangedRowKey := tablecodec.EncodeRowKeyWithHandle(physicalID, h)
txnCtx := ctx.GetSessionVars().TxnCtx
if txnCtx.IsPessimistic {
txnCtx.AddUnchangedRowKey(unchangedRowKey)
Expand Down
20 changes: 10 additions & 10 deletions expression/builtin_string.go
Original file line number Diff line number Diff line change
Expand Up @@ -277,7 +277,7 @@ func (c *concatFunctionClass) getFunction(ctx sessionctx.Context, args []Express
valStr, _ := ctx.GetSessionVars().GetSystemVar(variable.MaxAllowedPacket)
maxAllowedPacket, err := strconv.ParseUint(valStr, 10, 64)
if err != nil {
return nil, err
return nil, errors.Trace(err)
}

sig := &builtinConcatSig{bf, maxAllowedPacket}
Expand Down Expand Up @@ -354,7 +354,7 @@ func (c *concatWSFunctionClass) getFunction(ctx sessionctx.Context, args []Expre
valStr, _ := ctx.GetSessionVars().GetSystemVar(variable.MaxAllowedPacket)
maxAllowedPacket, err := strconv.ParseUint(valStr, 10, 64)
if err != nil {
return nil, err
return nil, errors.Trace(err)
}

sig := &builtinConcatWSSig{bf, maxAllowedPacket}
Expand Down Expand Up @@ -591,7 +591,7 @@ func (c *repeatFunctionClass) getFunction(ctx sessionctx.Context, args []Express
valStr, _ := ctx.GetSessionVars().GetSystemVar(variable.MaxAllowedPacket)
maxAllowedPacket, err := strconv.ParseUint(valStr, 10, 64)
if err != nil {
return nil, err
return nil, errors.Trace(err)
}
sig := &builtinRepeatSig{bf, maxAllowedPacket}
return sig, nil
Expand Down Expand Up @@ -758,7 +758,7 @@ func (c *spaceFunctionClass) getFunction(ctx sessionctx.Context, args []Expressi
valStr, _ := ctx.GetSessionVars().GetSystemVar(variable.MaxAllowedPacket)
maxAllowedPacket, err := strconv.ParseUint(valStr, 10, 64)
if err != nil {
return nil, err
return nil, errors.Trace(err)
}
sig := &builtinSpaceSig{bf, maxAllowedPacket}
return sig, nil
Expand Down Expand Up @@ -1853,7 +1853,7 @@ func (c *lpadFunctionClass) getFunction(ctx sessionctx.Context, args []Expressio
valStr, _ := ctx.GetSessionVars().GetSystemVar(variable.MaxAllowedPacket)
maxAllowedPacket, err := strconv.ParseUint(valStr, 10, 64)
if err != nil {
return nil, err
return nil, errors.Trace(err)
}

if types.IsBinaryStr(args[0].GetType()) || types.IsBinaryStr(args[2].GetType()) {
Expand Down Expand Up @@ -1981,7 +1981,7 @@ func (c *rpadFunctionClass) getFunction(ctx sessionctx.Context, args []Expressio
valStr, _ := ctx.GetSessionVars().GetSystemVar(variable.MaxAllowedPacket)
maxAllowedPacket, err := strconv.ParseUint(valStr, 10, 64)
if err != nil {
return nil, err
return nil, errors.Trace(err)
}

if types.IsBinaryStr(args[0].GetType()) || types.IsBinaryStr(args[2].GetType()) {
Expand Down Expand Up @@ -2607,7 +2607,7 @@ func (b *builtinOctStringSig) evalString(row chunk.Row) (string, bool, error) {
if err != nil {
numError, ok := err.(*strconv.NumError)
if !ok || numError.Err != strconv.ErrRange {
return "", true, err
return "", true, errors.Trace(err)
}
overflow = true
}
Expand Down Expand Up @@ -3161,7 +3161,7 @@ func (c *fromBase64FunctionClass) getFunction(ctx sessionctx.Context, args []Exp
valStr, _ := ctx.GetSessionVars().GetSystemVar(variable.MaxAllowedPacket)
maxAllowedPacket, err := strconv.ParseUint(valStr, 10, 64)
if err != nil {
return nil, err
return nil, errors.Trace(err)
}

types.SetBinChsClnFlag(bf.tp)
Expand Down Expand Up @@ -3234,7 +3234,7 @@ func (c *toBase64FunctionClass) getFunction(ctx sessionctx.Context, args []Expre
valStr, _ := ctx.GetSessionVars().GetSystemVar(variable.MaxAllowedPacket)
maxAllowedPacket, err := strconv.ParseUint(valStr, 10, 64)
if err != nil {
return nil, err
return nil, errors.Trace(err)
}

sig := &builtinToBase64Sig{bf, maxAllowedPacket}
Expand Down Expand Up @@ -3335,7 +3335,7 @@ func (c *insertFunctionClass) getFunction(ctx sessionctx.Context, args []Express
valStr, _ := ctx.GetSessionVars().GetSystemVar(variable.MaxAllowedPacket)
maxAllowedPacket, err := strconv.ParseUint(valStr, 10, 64)
if err != nil {
return nil, err
return nil, errors.Trace(err)
}

if types.IsBinaryStr(args[0].GetType()) {
Expand Down
3 changes: 2 additions & 1 deletion planner/core/exhaust_physical_plans.go
Original file line number Diff line number Diff line change
Expand Up @@ -1280,7 +1280,8 @@ func (p *LogicalLimit) exhaustPhysicalPlans(prop *property.PhysicalProperty) []P
func (p *LogicalLock) exhaustPhysicalPlans(prop *property.PhysicalProperty) []PhysicalPlan {
childProp := prop.Clone()
lock := PhysicalLock{
Lock: p.Lock,
Lock: p.Lock,
PartitionedTable: p.partitionedTable,
}.Init(p.ctx, p.stats.ScaleByExpectCnt(prop.ExpectedCnt), childProp)
return []PhysicalPlan{lock}
}
Expand Down
1 change: 1 addition & 0 deletions planner/core/logical_plan_builder.go
Original file line number Diff line number Diff line change
Expand Up @@ -2239,6 +2239,7 @@ func (b *PlanBuilder) buildDataSource(ctx context.Context, tn *ast.TableName) (L

if tableInfo.GetPartitionInfo() != nil {
b.optFlag = b.optFlag | flagPartitionProcessor
b.partitionedTable = append(b.partitionedTable, tbl.(table.PartitionedTable))
// check partition by name.
for _, name := range tn.PartitionNames {
_, err = tables.FindPartitionByName(tableInfo, name.L)
Expand Down
3 changes: 2 additions & 1 deletion planner/core/logical_plans.go
Original file line number Diff line number Diff line change
Expand Up @@ -684,7 +684,8 @@ type LogicalLimit struct {
type LogicalLock struct {
baseLogicalPlan

Lock ast.SelectLockType
Lock ast.SelectLockType
partitionedTable []table.PartitionedTable
}

// WindowFrame represents a window function frame.
Expand Down
3 changes: 3 additions & 0 deletions planner/core/physical_plans.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ import (
"github.com/pingcap/tidb/planner/property"
"github.com/pingcap/tidb/sessionctx/stmtctx"
"github.com/pingcap/tidb/statistics"
"github.com/pingcap/tidb/table"
"github.com/pingcap/tidb/types"
"github.com/pingcap/tidb/util/ranger"
)
Expand Down Expand Up @@ -280,6 +281,8 @@ type PhysicalLock struct {
basePhysicalPlan

Lock ast.SelectLockType

PartitionedTable []table.PartitionedTable
}

// PhysicalLimit is the physical operator of Limit.
Expand Down
5 changes: 4 additions & 1 deletion planner/core/planbuilder.go
Original file line number Diff line number Diff line change
Expand Up @@ -200,6 +200,9 @@ type PlanBuilder struct {
inStraightJoin bool

windowSpecs map[string]*ast.WindowSpec

// SelectLock need this information to locate the lock on partitions.
partitionedTable []table.PartitionedTable
}

// GetVisitInfo gets the visitInfo of the PlanBuilder.
Expand Down Expand Up @@ -575,7 +578,7 @@ func removeIgnoredPaths(paths, ignoredPaths []*accessPath, tblInfo *model.TableI
}

func (b *PlanBuilder) buildSelectLock(src LogicalPlan, lock ast.SelectLockType) *LogicalLock {
selectLock := LogicalLock{Lock: lock}.Init(b.ctx)
selectLock := LogicalLock{Lock: lock, partitionedTable: b.partitionedTable}.Init(b.ctx)
selectLock.SetChildren(src)
return selectLock
}
Expand Down
6 changes: 6 additions & 0 deletions planner/core/rule_column_pruning.go
Original file line number Diff line number Diff line change
Expand Up @@ -368,6 +368,12 @@ func (p *LogicalLock) PruneColumns(parentUsedCols []*expression.Column) error {
return p.baseLogicalPlan.PruneColumns(parentUsedCols)
}

if len(p.partitionedTable) > 0 {
// If the children include partitioned tables, do not prune columns.
// Because the executor needs the partitioned columns to calculate the lock key.
return p.children[0].PruneColumns(p.Schema().Columns)
}

for _, cols := range p.children[0].Schema().TblID2Handle {
parentUsedCols = append(parentUsedCols, cols...)
}
Expand Down
4 changes: 4 additions & 0 deletions session/session.go
Original file line number Diff line number Diff line change
Expand Up @@ -807,6 +807,10 @@ func createSessionFunc(store kv.Storage) pools.Factory {
if err != nil {
return nil, errors.Trace(err)
}
err = variable.SetSessionSystemVar(se.sessionVars, variable.MaxAllowedPacket, types.NewStringDatum("67108864"))
if err != nil {
return nil, errors.Trace(err)
}
se.sessionVars.CommonGlobalLoaded = true
se.sessionVars.InRestrictedSQL = true
return se, nil
Expand Down
66 changes: 66 additions & 0 deletions session/session_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2799,3 +2799,69 @@ func (s *testSessionSuite) TestGrantViewRelated(c *C) {
tkUser.MustQuery("select current_user();").Check(testkit.Rows("u_version29@%"))
tkUser.MustExec("create view v_version29_c as select * from v_version29;")
}

func (s *testSessionSuite) TestPessimisticLockOnPartition(c *C) {
// This test checks that 'select ... for update' locks the partition instead of the table.
// Cover a bug that table ID is used to encode the lock key mistakenly.
tk := testkit.NewTestKit(c, s.store)
tk.MustExec("use test")
tk.MustExec(`create table if not exists forupdate_on_partition (
age int not null primary key,
nickname varchar(20) not null,
gender int not null default 0,
first_name varchar(30) not null default '',
last_name varchar(20) not null default '',
full_name varchar(60) as (concat(first_name, ' ', last_name)),
index idx_nickname (nickname)
) partition by range (age) (
partition child values less than (18),
partition young values less than (30),
partition middle values less than (50),
partition old values less than (123)
);`)
tk.MustExec("insert into forupdate_on_partition (`age`, `nickname`) values (25, 'cosven');")

tk1 := testkit.NewTestKit(c, s.store)
tk1.MustExec("use test")

tk.MustExec("begin pessimistic")
tk.MustQuery("select * from forupdate_on_partition where age=25 for update").Check(testkit.Rows("25 cosven 0 "))
tk1.MustExec("begin pessimistic")

ch := make(chan int32, 5)
go func() {
tk1.MustExec("update forupdate_on_partition set first_name='sw' where age=25")
ch <- 0
tk1.MustExec("commit")
}()

// Leave 50ms for tk1 to run, tk1 should be blocked at the update operation.
time.Sleep(50 * time.Millisecond)
ch <- 1

tk.MustExec("commit")
// tk1 should be blocked until tk commit, check the order.
c.Assert(<-ch, Equals, int32(1))
c.Assert(<-ch, Equals, int32(0))

// Once again...
// This time, test for the update-update conflict.
tk.MustExec("begin pessimistic")
tk.MustExec("update forupdate_on_partition set first_name='sw' where age=25")
tk1.MustExec("begin pessimistic")

go func() {
tk1.MustExec("update forupdate_on_partition set first_name = 'xxx' where age=25")
ch <- 0
tk1.MustExec("commit")
}()

// Leave 50ms for tk1 to run, tk1 should be blocked at the update operation.
time.Sleep(50 * time.Millisecond)
ch <- 1

tk.MustExec("commit")
// tk1 should be blocked until tk commit, check the order.
c.Assert(<-ch, Equals, int32(1))
c.Assert(<-ch, Equals, int32(0))
}
2 changes: 1 addition & 1 deletion table/table.go
Original file line number Diff line number Diff line change
Expand Up @@ -212,7 +212,7 @@ type PhysicalTable interface {
type PartitionedTable interface {
Table
GetPartition(physicalID int64) PhysicalTable
GetPartitionByRow(sessionctx.Context, []types.Datum) (Table, error)
GetPartitionByRow(sessionctx.Context, []types.Datum) (PhysicalTable, error)
}

// TableFromMeta builds a table.Table from *model.TableInfo.
Expand Down
2 changes: 1 addition & 1 deletion table/tables/partition.go
Original file line number Diff line number Diff line change
Expand Up @@ -339,7 +339,7 @@ func (t *partitionedTable) GetPartition(pid int64) table.PhysicalTable {
}

// GetPartitionByRow returns a Table, which is actually a Partition.
func (t *partitionedTable) GetPartitionByRow(ctx sessionctx.Context, r []types.Datum) (table.Table, error) {
func (t *partitionedTable) GetPartitionByRow(ctx sessionctx.Context, r []types.Datum) (table.PhysicalTable, error) {
pid, err := t.locatePartition(ctx, t.Meta().GetPartitionInfo(), r)
if err != nil {
return nil, errors.Trace(err)
Expand Down