Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

planner: store the hints of session variable (#45814) #46045

Open
wants to merge 4 commits into
base: release-6.1
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 6 additions & 2 deletions planner/core/cache.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ import (
"github.com/pingcap/tidb/parser/model"
"github.com/pingcap/tidb/parser/mysql"
"github.com/pingcap/tidb/sessionctx"
"github.com/pingcap/tidb/sessionctx/stmtctx"
"github.com/pingcap/tidb/sessionctx/variable"
"github.com/pingcap/tidb/types"
"github.com/pingcap/tidb/util/codec"
Expand Down Expand Up @@ -181,6 +182,8 @@ type PlanCacheValue struct {
BinVarTypes []byte // variable types under binary protocol
IsBinProto bool // whether this plan is under binary protocol
BindSQL string
// stmtHints stores the hints which set session variables, because the hints won't be processed using cached plan.
stmtHints *stmtctx.StmtHints
}

func (v *PlanCacheValue) varTypesUnchanged(binVarTps []byte, txtVarTps []*types.FieldType) bool {
Expand All @@ -192,7 +195,7 @@ func (v *PlanCacheValue) varTypesUnchanged(binVarTps []byte, txtVarTps []*types.

// NewPlanCacheValue creates a SQLCacheValue.
func NewPlanCacheValue(plan Plan, names []*types.FieldName, srcMap map[*model.TableInfo]bool,
isBinProto bool, binVarTypes []byte, txtVarTps []*types.FieldType, bindSQL string) *PlanCacheValue {
isBinProto bool, binVarTypes []byte, txtVarTps []*types.FieldType, stmtCtx *stmtctx.StatementContext) *PlanCacheValue {
dstMap := make(map[*model.TableInfo]bool)
for k, v := range srcMap {
dstMap[k] = v
Expand All @@ -208,7 +211,8 @@ func NewPlanCacheValue(plan Plan, names []*types.FieldName, srcMap map[*model.Ta
TxtVarTypes: userVarTypes,
BinVarTypes: binVarTypes,
IsBinProto: isBinProto,
BindSQL: bindSQL,
BindSQL: stmtCtx.BindSQL,
stmtHints: stmtCtx.StmtHints.Clone(),
}
}

Expand Down
13 changes: 11 additions & 2 deletions planner/core/common_plans.go
Original file line number Diff line number Diff line change
Expand Up @@ -509,6 +509,11 @@ func (e *Execute) getPhysicalPlan(ctx context.Context, sctx sessionctx.Context,
e.names = names
e.Plan = plan
stmtCtx.PointExec = true
if pointPlan, ok := plan.(*PointGetPlan); ok {
if pointPlan.stmtHints != nil {
stmtCtx.StmtHints = *pointPlan.stmtHints
}
}
return nil
}
if prepared.UseCache && !ignorePlanCache { // for general plans
Expand Down Expand Up @@ -564,6 +569,9 @@ func (e *Execute) getPhysicalPlan(ctx context.Context, sctx sessionctx.Context,
e.names = cachedVal.OutPutNames
e.Plan = cachedVal.Plan
stmtCtx.SetPlanDigest(preparedStmt.NormalizedPlan, preparedStmt.PlanDigest)
if cachedVal.stmtHints != nil {
stmtCtx.StmtHints = *cachedVal.stmtHints
}
return nil
}
break
Expand Down Expand Up @@ -601,7 +609,7 @@ REBUILD:
}
sessVars.IsolationReadEngines[kv.TiFlash] = struct{}{}
}
cached := NewPlanCacheValue(p, names, stmtCtx.TblInfo2UnionScan, isBinProtocol, binVarTypes, txtVarTypes, sessVars.StmtCtx.BindSQL)
cached := NewPlanCacheValue(p, names, stmtCtx.TblInfo2UnionScan, isBinProtocol, binVarTypes, txtVarTypes, sessVars.StmtCtx)
preparedStmt.NormalizedPlan, preparedStmt.PlanDigest = NormalizePlan(p)
stmtCtx.SetPlanDigest(preparedStmt.NormalizedPlan, preparedStmt.PlanDigest)
if cacheVals, exists := sctx.PreparedPlanCache().Get(cacheKey); exists {
Expand Down Expand Up @@ -673,13 +681,14 @@ func (e *Execute) tryCachePointPlan(ctx context.Context, sctx sessionctx.Context
err error
names types.NameSlice
)
switch p.(type) {
switch pointPlan := p.(type) {
case *PointGetPlan:
ok, err = IsPointGetWithPKOrUniqueKeyByAutoCommit(sctx, p)
names = p.OutputNames()
if err != nil {
return err
}
pointPlan.stmtHints = sctx.GetSessionVars().StmtCtx.StmtHints.Clone()
}
if ok {
// just cache point plan now
Expand Down
2 changes: 2 additions & 0 deletions planner/core/point_get_plan.go
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,8 @@ type PointGetPlan struct {
planCost float64
// accessCols represents actual columns the PointGet will access, which are used to calculate row-size
accessCols []*expression.Column
// stmtHints should restore in executing context.
stmtHints *stmtctx.StmtHints
}

type nameValuePair struct {
Expand Down
36 changes: 35 additions & 1 deletion sessionctx/stmtctx/stmtctx.go
Original file line number Diff line number Diff line change
Expand Up @@ -300,7 +300,6 @@ type StatementContext struct {
type StmtHints struct {
// Hint Information
MemQuotaQuery int64
ApplyCacheCapacity int64
MaxExecutionTime uint64
ReplicaRead byte
AllowInSubqToJoinAndAgg bool
Expand Down Expand Up @@ -329,6 +328,41 @@ func (sh *StmtHints) TaskMapNeedBackUp() bool {
return sh.ForceNthPlan != -1
}

// Clone the StmtHints struct and returns the pointer of the new one.
func (sh *StmtHints) Clone() *StmtHints {
var (
vars map[string]string
tableHints []*ast.TableOptimizerHint
)
if len(sh.SetVars) > 0 {
vars = make(map[string]string, len(sh.SetVars))
for k, v := range sh.SetVars {
vars[k] = v
}
}
if len(sh.OriginalTableHints) > 0 {
tableHints = make([]*ast.TableOptimizerHint, len(sh.OriginalTableHints))
copy(tableHints, sh.OriginalTableHints)
}
return &StmtHints{
MemQuotaQuery: sh.MemQuotaQuery,
MaxExecutionTime: sh.MaxExecutionTime,
ReplicaRead: sh.ReplicaRead,
AllowInSubqToJoinAndAgg: sh.AllowInSubqToJoinAndAgg,
NoIndexMergeHint: sh.NoIndexMergeHint,
StraightJoinOrder: sh.StraightJoinOrder,
EnableCascadesPlanner: sh.EnableCascadesPlanner,
ForceNthPlan: sh.ForceNthPlan,
HasAllowInSubqToJoinAndAggHint: sh.HasAllowInSubqToJoinAndAggHint,
HasMemQuotaHint: sh.HasMemQuotaHint,
HasReplicaReadHint: sh.HasReplicaReadHint,
HasMaxExecutionTime: sh.HasMaxExecutionTime,
HasEnableCascadesPlannerHint: sh.HasEnableCascadesPlannerHint,
SetVars: vars,
OriginalTableHints: tableHints,
}
}

// StmtCacheKey represents the key type in the StmtCache.
type StmtCacheKey int

Expand Down
23 changes: 23 additions & 0 deletions sessionctx/stmtctx/stmtctx_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ package stmtctx_test
import (
"context"
"fmt"
"reflect"
"testing"
"time"

Expand Down Expand Up @@ -143,3 +144,25 @@ func TestWeakConsistencyRead(t *testing.T) {
execAndCheck("execute s", testkit.Rows("1 1 2"), kv.SI)
tk.MustExec("rollback")
}

func TestStmtHintsClone(t *testing.T) {
hints := stmtctx.StmtHints{}
value := reflect.ValueOf(&hints).Elem()
for i := 0; i < value.NumField(); i++ {
field := value.Field(i)
switch field.Kind() {
case reflect.Int, reflect.Int32, reflect.Int64:
field.SetInt(1)
case reflect.Uint, reflect.Uint32, reflect.Uint64:
field.SetUint(1)
case reflect.Uint8: // byte
field.SetUint(1)
case reflect.Bool:
field.SetBool(true)
case reflect.String:
field.SetString("test")
default:
}
}
require.Equal(t, hints, *hints.Clone())
}
62 changes: 62 additions & 0 deletions tests/realtikvtest/sessiontest/session_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ import (
"github.com/pingcap/tidb/privilege/privileges"
"github.com/pingcap/tidb/session"
"github.com/pingcap/tidb/sessionctx"
"github.com/pingcap/tidb/sessionctx/stmtctx"
"github.com/pingcap/tidb/sessionctx/variable"
"github.com/pingcap/tidb/store/copr"
"github.com/pingcap/tidb/store/mockstore"
Expand Down Expand Up @@ -3792,3 +3793,64 @@ func TestSQLModeOp(t *testing.T) {
a = mysql.SetSQLMode(s, mysql.ModeAllowInvalidDates)
require.Equal(t, mysql.ModeNoBackslashEscapes|mysql.ModeOnlyFullGroupBy|mysql.ModeAllowInvalidDates, a)
}

func TestPrepareExecuteWithSQLHints(t *testing.T) {
store, clean := realtikvtest.CreateMockStoreAndSetup(t)
defer clean()
tk := testkit.NewTestKit(t, store)
se := tk.Session()
se.SetConnectionID(1)
tk.MustExec("use test")
tk.MustExec("create table t(a int primary key)")

type hintCheck struct {
hint string
check func(*stmtctx.StmtHints)
}

hintChecks := []hintCheck{
{
hint: "MEMORY_QUOTA(1024 MB)",
check: func(stmtHint *stmtctx.StmtHints) {
require.True(t, stmtHint.HasMemQuotaHint)
require.Equal(t, int64(1024*1024*1024), stmtHint.MemQuotaQuery)
},
},
{
hint: "READ_CONSISTENT_REPLICA()",
check: func(stmtHint *stmtctx.StmtHints) {
require.True(t, stmtHint.HasReplicaReadHint)
require.Equal(t, byte(kv.ReplicaReadFollower), stmtHint.ReplicaRead)
},
},
{
hint: "MAX_EXECUTION_TIME(1000)",
check: func(stmtHint *stmtctx.StmtHints) {
require.True(t, stmtHint.HasMaxExecutionTime)
require.Equal(t, uint64(1000), stmtHint.MaxExecutionTime)
},
},
{
hint: "USE_TOJA(TRUE)",
check: func(stmtHint *stmtctx.StmtHints) {
require.True(t, stmtHint.HasAllowInSubqToJoinAndAggHint)
require.True(t, stmtHint.AllowInSubqToJoinAndAgg)
},
},
}

for i, check := range hintChecks {
// common path
tk.MustExec(fmt.Sprintf("prepare stmt%d from 'select /*+ %s */ * from t'", i, check.hint))
for j := 0; j < 10; j++ {
tk.MustQuery(fmt.Sprintf("execute stmt%d", i))
check.check(&tk.Session().GetSessionVars().StmtCtx.StmtHints)
}
// fast path
tk.MustExec(fmt.Sprintf("prepare fast%d from 'select /*+ %s */ * from t where a = 1'", i, check.hint))
for j := 0; j < 10; j++ {
tk.MustQuery(fmt.Sprintf("execute fast%d", i))
check.check(&tk.Session().GetSessionVars().StmtCtx.StmtHints)
}
}
}