Skip to content

Commit

Permalink
planner: store the hints of session variable (#45814) (#46047)
Browse files Browse the repository at this point in the history
close #45812
  • Loading branch information
ti-chi-bot authored Oct 16, 2023
1 parent d495d63 commit c3b8757
Show file tree
Hide file tree
Showing 8 changed files with 146 additions and 5 deletions.
11 changes: 9 additions & 2 deletions planner/core/plan_cache.go
Original file line number Diff line number Diff line change
Expand Up @@ -247,6 +247,9 @@ func getCachedPointPlan(stmt *ast.Prepared, sessVars *variable.SessionVars, stmt
}
sessVars.FoundInPlanCache = true
stmtCtx.PointExec = true
if pointGetPlan, ok := plan.(*PointGetPlan); ok && pointGetPlan != nil && pointGetPlan.stmtHints != nil {
sessVars.StmtCtx.StmtHints = *pointGetPlan.stmtHints
}
return plan, names, true, nil
}

Expand Down Expand Up @@ -287,6 +290,7 @@ func getCachedPlan(sctx sessionctx.Context, isNonPrepared bool, cacheKey kvcache
core_metrics.GetPlanCacheHitCounter(isNonPrepared).Inc()
}
stmtCtx.SetPlanDigest(stmt.NormalizedPlan, stmt.PlanDigest)
stmtCtx.StmtHints = *cachedVal.stmtHints
return cachedVal.Plan, cachedVal.OutPutNames, true, nil
}

Expand Down Expand Up @@ -329,7 +333,7 @@ func generateNewPlan(ctx context.Context, sctx sessionctx.Context, isNonPrepared
}
sessVars.IsolationReadEngines[kv.TiFlash] = struct{}{}
}
cached := NewPlanCacheValue(p, names, stmtCtx.TblInfo2UnionScan, matchOpts)
cached := NewPlanCacheValue(p, names, stmtCtx.TblInfo2UnionScan, matchOpts, &stmtCtx.StmtHints)
stmt.NormalizedPlan, stmt.PlanDigest = NormalizePlan(p)
stmtCtx.SetPlan(p)
stmtCtx.SetPlanDigest(stmt.NormalizedPlan, stmt.PlanDigest)
Expand Down Expand Up @@ -759,12 +763,15 @@ func tryCachePointPlan(_ context.Context, sctx sessionctx.Context,
names types.NameSlice
)

if _, _ok := p.(*PointGetPlan); _ok {
if plan, _ok := p.(*PointGetPlan); _ok {
ok, err = IsPointGetWithPKOrUniqueKeyByAutoCommit(sctx, p)
names = p.OutputNames()
if err != nil {
return err
}
if ok {
plan.stmtHints = sctx.GetSessionVars().StmtCtx.StmtHints.Clone()
}
}

if ok {
Expand Down
6 changes: 5 additions & 1 deletion planner/core/plan_cache_utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ import (
"github.com/pingcap/tidb/parser/model"
"github.com/pingcap/tidb/parser/mysql"
"github.com/pingcap/tidb/sessionctx"
"github.com/pingcap/tidb/sessionctx/stmtctx"
"github.com/pingcap/tidb/sessionctx/variable"
"github.com/pingcap/tidb/statistics"
"github.com/pingcap/tidb/types"
Expand Down Expand Up @@ -335,6 +336,8 @@ type PlanCacheValue struct {

// matchOpts stores some fields help to choose a suitable plan
matchOpts *utilpc.PlanCacheMatchOpts
// stmtHints stores the hints which set session variables, because the hints won't be processed using cached plan.
stmtHints *stmtctx.StmtHints
}

func (v *PlanCacheValue) varTypesUnchanged(txtVarTps []*types.FieldType) bool {
Expand Down Expand Up @@ -385,7 +388,7 @@ func (v *PlanCacheValue) MemoryUsage() (sum int64) {

// NewPlanCacheValue creates a SQLCacheValue.
func NewPlanCacheValue(plan Plan, names []*types.FieldName, srcMap map[*model.TableInfo]bool,
matchOpts *utilpc.PlanCacheMatchOpts) *PlanCacheValue {
matchOpts *utilpc.PlanCacheMatchOpts, stmtHints *stmtctx.StmtHints) *PlanCacheValue {
dstMap := make(map[*model.TableInfo]bool)
for k, v := range srcMap {
dstMap[k] = v
Expand All @@ -399,6 +402,7 @@ func NewPlanCacheValue(plan Plan, names []*types.FieldName, srcMap map[*model.Ta
OutPutNames: names,
TblInfo2UnionScan: dstMap,
matchOpts: matchOpts,
stmtHints: stmtHints.Clone(),
}
}

Expand Down
2 changes: 2 additions & 0 deletions planner/core/point_get_plan.go
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,8 @@ type PointGetPlan struct {
// probeParents records the IndexJoins and Applys with this operator in their inner children.
// Please see comments in PhysicalPlan for details.
probeParents []PhysicalPlan
// stmtHints should restore in executing context.
stmtHints *stmtctx.StmtHints
}

func (p *PointGetPlan) getEstRowCountForDisplay() float64 {
Expand Down
1 change: 1 addition & 0 deletions session/sessiontest/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ go_test(
"//privilege/privileges",
"//session",
"//sessionctx",
"//sessionctx/stmtctx",
"//sessionctx/variable",
"//store/copr",
"//store/mockstore",
Expand Down
68 changes: 68 additions & 0 deletions session/sessiontest/session_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ import (
"github.com/pingcap/tidb/privilege/privileges"
"github.com/pingcap/tidb/session"
"github.com/pingcap/tidb/sessionctx"
"github.com/pingcap/tidb/sessionctx/stmtctx"
"github.com/pingcap/tidb/sessionctx/variable"
"github.com/pingcap/tidb/store/copr"
"github.com/pingcap/tidb/store/mockstore"
Expand Down Expand Up @@ -3602,3 +3603,70 @@ func TestSQLModeOp(t *testing.T) {
a = mysql.SetSQLMode(s, mysql.ModeAllowInvalidDates)
require.Equal(t, mysql.ModeNoBackslashEscapes|mysql.ModeOnlyFullGroupBy|mysql.ModeAllowInvalidDates, a)
}

func TestPrepareExecuteWithSQLHints(t *testing.T) {
store := testkit.CreateMockStore(t)
tk := testkit.NewTestKit(t, store)
se := tk.Session()
se.SetConnectionID(1)
tk.MustExec("use test")
tk.MustExec("create table t(a int primary key)")

type hintCheck struct {
hint string
check func(*stmtctx.StmtHints)
}

hintChecks := []hintCheck{
{
hint: "MEMORY_QUOTA(1024 MB)",
check: func(stmtHint *stmtctx.StmtHints) {
require.True(t, stmtHint.HasMemQuotaHint)
require.Equal(t, int64(1024*1024*1024), stmtHint.MemQuotaQuery)
},
},
{
hint: "READ_CONSISTENT_REPLICA()",
check: func(stmtHint *stmtctx.StmtHints) {
require.True(t, stmtHint.HasReplicaReadHint)
require.Equal(t, byte(kv.ReplicaReadFollower), stmtHint.ReplicaRead)
},
},
{
hint: "MAX_EXECUTION_TIME(1000)",
check: func(stmtHint *stmtctx.StmtHints) {
require.True(t, stmtHint.HasMaxExecutionTime)
require.Equal(t, uint64(1000), stmtHint.MaxExecutionTime)
},
},
{
hint: "USE_TOJA(TRUE)",
check: func(stmtHint *stmtctx.StmtHints) {
require.True(t, stmtHint.HasAllowInSubqToJoinAndAggHint)
require.True(t, stmtHint.AllowInSubqToJoinAndAgg)
},
},
{
hint: "RESOURCE_GROUP(rg1)",
check: func(stmtHint *stmtctx.StmtHints) {
require.True(t, stmtHint.HasResourceGroup)
require.Equal(t, "rg1", stmtHint.ResourceGroup)
},
},
}

for i, check := range hintChecks {
// common path
tk.MustExec(fmt.Sprintf("prepare stmt%d from 'select /*+ %s */ * from t'", i, check.hint))
for j := 0; j < 10; j++ {
tk.MustQuery(fmt.Sprintf("execute stmt%d", i))
check.check(&tk.Session().GetSessionVars().StmtCtx.StmtHints)
}
// fast path
tk.MustExec(fmt.Sprintf("prepare fast%d from 'select /*+ %s */ * from t where a = 1'", i, check.hint))
for j := 0; j < 10; j++ {
tk.MustQuery(fmt.Sprintf("execute fast%d", i))
check.check(&tk.Session().GetSessionVars().StmtCtx.StmtHints)
}
}
}
2 changes: 1 addition & 1 deletion sessionctx/stmtctx/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ go_test(
],
embed = [":stmtctx"],
flaky = True,
shard_count = 5,
shard_count = 6,
deps = [
"//kv",
"//sessionctx/variable",
Expand Down
38 changes: 37 additions & 1 deletion sessionctx/stmtctx/stmtctx.go
Original file line number Diff line number Diff line change
Expand Up @@ -415,7 +415,6 @@ type StatementContext struct {
type StmtHints struct {
// Hint Information
MemQuotaQuery int64
ApplyCacheCapacity int64
MaxExecutionTime uint64
ReplicaRead byte
AllowInSubqToJoinAndAgg bool
Expand Down Expand Up @@ -446,6 +445,43 @@ func (sh *StmtHints) TaskMapNeedBackUp() bool {
return sh.ForceNthPlan != -1
}

// Clone the StmtHints struct and returns the pointer of the new one.
func (sh *StmtHints) Clone() *StmtHints {
var (
vars map[string]string
tableHints []*ast.TableOptimizerHint
)
if len(sh.SetVars) > 0 {
vars = make(map[string]string, len(sh.SetVars))
for k, v := range sh.SetVars {
vars[k] = v
}
}
if len(sh.OriginalTableHints) > 0 {
tableHints = make([]*ast.TableOptimizerHint, len(sh.OriginalTableHints))
copy(tableHints, sh.OriginalTableHints)
}
return &StmtHints{
MemQuotaQuery: sh.MemQuotaQuery,
MaxExecutionTime: sh.MaxExecutionTime,
ReplicaRead: sh.ReplicaRead,
AllowInSubqToJoinAndAgg: sh.AllowInSubqToJoinAndAgg,
NoIndexMergeHint: sh.NoIndexMergeHint,
StraightJoinOrder: sh.StraightJoinOrder,
EnableCascadesPlanner: sh.EnableCascadesPlanner,
ForceNthPlan: sh.ForceNthPlan,
ResourceGroup: sh.ResourceGroup,
HasAllowInSubqToJoinAndAggHint: sh.HasAllowInSubqToJoinAndAggHint,
HasMemQuotaHint: sh.HasMemQuotaHint,
HasReplicaReadHint: sh.HasReplicaReadHint,
HasMaxExecutionTime: sh.HasMaxExecutionTime,
HasEnableCascadesPlannerHint: sh.HasEnableCascadesPlannerHint,
HasResourceGroup: sh.HasResourceGroup,
SetVars: vars,
OriginalTableHints: tableHints,
}
}

// StmtCacheKey represents the key type in the StmtCache.
type StmtCacheKey int

Expand Down
23 changes: 23 additions & 0 deletions sessionctx/stmtctx/stmtctx_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ import (
"encoding/json"
"fmt"
"math/rand"
"reflect"
"sort"
"testing"
"time"
Expand Down Expand Up @@ -273,3 +274,25 @@ func TestApproxRuntimeInfo(t *testing.T) {
require.Equal(t, d.TotBackoffTime[backoff], timeSum)
}
}

func TestStmtHintsClone(t *testing.T) {
hints := stmtctx.StmtHints{}
value := reflect.ValueOf(&hints).Elem()
for i := 0; i < value.NumField(); i++ {
field := value.Field(i)
switch field.Kind() {
case reflect.Int, reflect.Int32, reflect.Int64:
field.SetInt(1)
case reflect.Uint, reflect.Uint32, reflect.Uint64:
field.SetUint(1)
case reflect.Uint8: // byte
field.SetUint(1)
case reflect.Bool:
field.SetBool(true)
case reflect.String:
field.SetString("test")
default:
}
}
require.Equal(t, hints, *hints.Clone())
}

0 comments on commit c3b8757

Please sign in to comment.