From 5b917591642a4491550a63d3bb77ea1911511ea1 Mon Sep 17 00:00:00 2001 From: you06 Date: Mon, 14 Aug 2023 12:42:28 +0800 Subject: [PATCH] This is an automated cherry-pick of #45814 Signed-off-by: ti-chi-bot --- planner/core/plan_cache.go | 856 +++++++++++++++++++++++++++++ planner/core/plan_cache_utils.go | 665 ++++++++++++++++++++++ planner/core/point_get_plan.go | 27 + session/test/vars/BUILD.bazel | 30 + session/test/vars/vars_test.go | 685 +++++++++++++++++++++++ sessionctx/stmtctx/BUILD.bazel | 54 ++ sessionctx/stmtctx/stmtctx.go | 40 +- sessionctx/stmtctx/stmtctx_test.go | 154 ++++++ 8 files changed, 2510 insertions(+), 1 deletion(-) create mode 100644 planner/core/plan_cache.go create mode 100644 planner/core/plan_cache_utils.go create mode 100644 session/test/vars/BUILD.bazel create mode 100644 session/test/vars/vars_test.go create mode 100644 sessionctx/stmtctx/BUILD.bazel diff --git a/planner/core/plan_cache.go b/planner/core/plan_cache.go new file mode 100644 index 0000000000000..8623f59173c12 --- /dev/null +++ b/planner/core/plan_cache.go @@ -0,0 +1,856 @@ +// Copyright 2022 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package core + +import ( + "context" + + "github.com/pingcap/errors" + "github.com/pingcap/tidb/bindinfo" + "github.com/pingcap/tidb/domain" + "github.com/pingcap/tidb/expression" + "github.com/pingcap/tidb/infoschema" + "github.com/pingcap/tidb/kv" + "github.com/pingcap/tidb/metrics" + "github.com/pingcap/tidb/parser/ast" + "github.com/pingcap/tidb/parser/mysql" + core_metrics "github.com/pingcap/tidb/planner/core/metrics" + "github.com/pingcap/tidb/planner/util/debugtrace" + "github.com/pingcap/tidb/privilege" + "github.com/pingcap/tidb/sessionctx" + "github.com/pingcap/tidb/sessionctx/stmtctx" + "github.com/pingcap/tidb/sessionctx/variable" + "github.com/pingcap/tidb/sessiontxn/staleread" + "github.com/pingcap/tidb/table/tables" + "github.com/pingcap/tidb/types" + driver "github.com/pingcap/tidb/types/parser_driver" + "github.com/pingcap/tidb/util/chunk" + "github.com/pingcap/tidb/util/collate" + "github.com/pingcap/tidb/util/kvcache" + utilpc "github.com/pingcap/tidb/util/plancache" + "github.com/pingcap/tidb/util/ranger" +) + +var ( + // PlanCacheKeyTestIssue43667 is for test. + PlanCacheKeyTestIssue43667 struct{} +) + +// SetParameterValuesIntoSCtx sets these parameters into session context. 
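+// Parameters arriving through the binary protocol are *expression.Constant values,
+// while the text protocol (EXECUTE ... USING @a) passes GetVar expressions; both
+// are evaluated here and recorded in vars.PlanCacheParams. markers may be nil when
+// there are no parameter markers to bind, e.g. on the non-prepared path.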
+func SetParameterValuesIntoSCtx(sctx sessionctx.Context, isNonPrep bool, markers []ast.ParamMarkerExpr, params []expression.Expression) error { + vars := sctx.GetSessionVars() + vars.PlanCacheParams.Reset() + for i, usingParam := range params { + val, err := usingParam.Eval(chunk.Row{}) + if err != nil { + return err + } + if isGetVarBinaryLiteral(sctx, usingParam) { + binVal, convErr := val.ToBytes() + if convErr != nil { + return convErr + } + val.SetBinaryLiteral(binVal) + } + if markers != nil { + param := markers[i].(*driver.ParamMarkerExpr) + param.Datum = val + param.InExecute = true + } + vars.PlanCacheParams.Append(val) + } + if vars.StmtCtx.EnableOptimizerDebugTrace && len(vars.PlanCacheParams.AllParamValues()) > 0 { + vals := vars.PlanCacheParams.AllParamValues() + valStrs := make([]string, len(vals)) + for i, val := range vals { + valStrs[i] = val.String() + } + debugtrace.RecordAnyValuesWithNames(sctx, "Parameter datums for EXECUTE", valStrs) + } + vars.PlanCacheParams.SetForNonPrepCache(isNonPrep) + return nil +} + +func planCachePreprocess(ctx context.Context, sctx sessionctx.Context, isNonPrepared bool, is infoschema.InfoSchema, stmt *PlanCacheStmt, params []expression.Expression) error { + vars := sctx.GetSessionVars() + stmtAst := stmt.PreparedAst + vars.StmtCtx.StmtType = stmtAst.StmtType + + // step 1: check parameter number + if len(stmtAst.Params) != len(params) { + return errors.Trace(ErrWrongParamCount) + } + + // step 2: set parameter values + if err := SetParameterValuesIntoSCtx(sctx, isNonPrepared, stmtAst.Params, params); err != nil { + return errors.Trace(err) + } + + // step 3: check schema version + if stmtAst.SchemaVersion != is.SchemaMetaVersion() { + // In order to avoid some correctness issues, we have to clear the + // cached plan once the schema version is changed. + // Cached plan in prepared struct does NOT have a "cache key" with + // schema version like prepared plan cache key + stmtAst.CachedPlan = nil + stmt.Executor = nil + stmt.ColumnInfos = nil + // If the schema version has changed we need to preprocess it again, + // if this time it failed, the real reason for the error is schema changed. + // Example: + // When running update in prepared statement's schema version distinguished from the one of execute statement + // We should reset the tableRefs in the prepared update statements, otherwise, the ast nodes still hold the old + // tableRefs columnInfo which will cause chaos in logic of trying point get plan. (should ban non-public column) + ret := &PreprocessorReturn{InfoSchema: is} + err := Preprocess(ctx, sctx, stmtAst.Stmt, InPrepare, WithPreprocessorReturn(ret)) + if err != nil { + return ErrSchemaChanged.GenWithStack("Schema change caused error: %s", err.Error()) + } + stmtAst.SchemaVersion = is.SchemaMetaVersion() + } + + // step 4: handle expiration + // If the lastUpdateTime less than expiredTimeStamp4PC, + // it means other sessions have executed 'admin flush instance plan_cache'. + // So we need to clear the current session's plan cache. + // And update lastUpdateTime to the newest one. + expiredTimeStamp4PC := domain.GetDomain(sctx).ExpiredTimeStamp4PC() + if stmt.StmtCacheable && expiredTimeStamp4PC.Compare(vars.LastUpdateTime4PC) > 0 { + sctx.GetSessionPlanCache().DeleteAll() + stmtAst.CachedPlan = nil + vars.LastUpdateTime4PC = expiredTimeStamp4PC + } + return nil +} + +// GetPlanFromSessionPlanCache is the entry point of Plan Cache. +// It tries to get a valid cached plan from this session's plan cache. 
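+// Lookup is keyed by NewPlanCacheKey together with the match options built by
+// GetMatchOpts (parameter types, LIMIT values, subquery flag and stats version).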
+// If there is no such plan, it'll call the optimizer to generate a new one.
+// isNonPrepared indicates whether to use the non-prepared plan cache or the prepared plan cache.
+func GetPlanFromSessionPlanCache(ctx context.Context, sctx sessionctx.Context,
+	isNonPrepared bool, is infoschema.InfoSchema, stmt *PlanCacheStmt,
+	params []expression.Expression) (plan Plan, names []*types.FieldName, err error) {
+	if err := planCachePreprocess(ctx, sctx, isNonPrepared, is, stmt, params); err != nil {
+		return nil, nil, err
+	}
+
+	var cacheKey kvcache.Key
+	sessVars := sctx.GetSessionVars()
+	stmtCtx := sessVars.StmtCtx
+	stmtAst := stmt.PreparedAst
+	stmtCtx.UseCache = stmt.StmtCacheable
+	if isNonPrepared {
+		stmtCtx.CacheType = stmtctx.SessionNonPrepared
+	} else {
+		stmtCtx.CacheType = stmtctx.SessionPrepared
+	}
+	if !stmt.StmtCacheable {
+		stmtCtx.SetSkipPlanCache(errors.New(stmt.UncacheableReason))
+	}
+
+	var bindSQL string
+	if stmtCtx.UseCache {
+		var ignoreByBinding bool
+		bindSQL, ignoreByBinding = GetBindSQL4PlanCache(sctx, stmt)
+		if ignoreByBinding {
+			stmtCtx.SetSkipPlanCache(errors.Errorf("ignore plan cache by binding"))
+		}
+	}
+
+	// Under read-committed or for-update read, we need the latest schema version to decide
+	// whether the plan must be rebuilt, so we set this value in those cases. Otherwise, leave it as 0.
+	var latestSchemaVersion int64
+
+	if stmtCtx.UseCache {
+		if sctx.GetSessionVars().IsIsolation(ast.ReadCommitted) || stmt.ForUpdateRead {
+			// In RC or for-update read, we should check whether the information schema has changed
+			// since last time. If it has, we should rebuild the plan. Here we use a different and
+			// more up-to-date schema version, which can lead to a plan cache miss so that the plan is rebuilt.
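+			// e.g. a concurrent ADD INDEX committed under READ COMMITTED bumps the
+			// schema version; keying on the latest version forces a cache miss
+			// instead of reusing a plan built against the old schema.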
+			latestSchemaVersion = domain.GetDomain(sctx).InfoSchema().SchemaMetaVersion()
+		}
+		if cacheKey, err = NewPlanCacheKey(sctx.GetSessionVars(), stmt.StmtText,
+			stmt.StmtDB, stmtAst.SchemaVersion, latestSchemaVersion, bindSQL, expression.ExprPushDownBlackListReloadTimeStamp.Load()); err != nil {
+			return nil, nil, err
+		}
+	}
+
+	if stmtCtx.UseCache && stmtAst.CachedPlan != nil { // special code path for fast point plan
+		if plan, names, ok, err := getCachedPointPlan(stmtAst, sessVars, stmtCtx); ok {
+			return plan, names, err
+		}
+	}
+
+	matchOpts, err := GetMatchOpts(sctx, is, stmt, params)
+	if err != nil {
+		return nil, nil, err
+	}
+	if stmtCtx.UseCache { // for non-point plans
+		if plan, names, ok, err := getCachedPlan(sctx, isNonPrepared, cacheKey, bindSQL, is, stmt, matchOpts); err != nil || ok {
+			return plan, names, err
+		}
+	}
+
+	return generateNewPlan(ctx, sctx, isNonPrepared, is, stmt, cacheKey, latestSchemaVersion, bindSQL, matchOpts)
+}
+
+// parseParamTypes gets the parameters' types in the PREPARE statement
+func parseParamTypes(sctx sessionctx.Context, params []expression.Expression) (paramTypes []*types.FieldType) {
+	paramTypes = make([]*types.FieldType, 0, len(params))
+	for _, param := range params {
+		if c, ok := param.(*expression.Constant); ok { // from binary protocol
+			paramTypes = append(paramTypes, c.GetType())
+			continue
+		}
+
+		// from text protocol, there must be a GetVar function
+		name := param.(*expression.ScalarFunction).GetArgs()[0].String()
+		tp, ok := sctx.GetSessionVars().GetUserVarType(name)
+		if !ok {
+			tp = types.NewFieldType(mysql.TypeNull)
+		}
+		paramTypes = append(paramTypes, tp)
+	}
+	return
+}
+
+func getCachedPointPlan(stmt *ast.Prepared, sessVars *variable.SessionVars, stmtCtx *stmtctx.StatementContext) (Plan,
+	[]*types.FieldName, bool, error) {
+	// short path for point-get plans
+	// Rewriting the expression in the select.where condition will convert its
+	// type from "paramMarker" to "Constant". When point-select queries are executed,
+	// the expression in the where condition will not be evaluated,
+	// so we don't need to consider whether prepared.useCache is enabled.
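+	// The cached plan below was stored by tryCachePointPlan; only its ranges need
+	// to be rebuilt (via RebuildPlan4CachedPlan) before it can be reused.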
+ plan := stmt.CachedPlan.(Plan) + names := stmt.CachedNames.(types.NameSlice) + if !RebuildPlan4CachedPlan(plan) { + return nil, nil, false, nil + } + if metrics.ResettablePlanCacheCounterFortTest { + metrics.PlanCacheCounter.WithLabelValues("prepare").Inc() + } else { + // only for prepared plan cache + core_metrics.GetPlanCacheHitCounter(false).Inc() + } + sessVars.FoundInPlanCache = true + stmtCtx.PointExec = true + if pointGetPlan, ok := plan.(*PointGetPlan); ok && pointGetPlan != nil && pointGetPlan.stmtHints != nil { + sessVars.StmtCtx.StmtHints = *pointGetPlan.stmtHints + } + return plan, names, true, nil +} + +func getCachedPlan(sctx sessionctx.Context, isNonPrepared bool, cacheKey kvcache.Key, bindSQL string, + is infoschema.InfoSchema, stmt *PlanCacheStmt, matchOpts *utilpc.PlanCacheMatchOpts) (Plan, + []*types.FieldName, bool, error) { + sessVars := sctx.GetSessionVars() + stmtCtx := sessVars.StmtCtx + + candidate, exist := sctx.GetSessionPlanCache().Get(cacheKey, matchOpts) + if !exist { + return nil, nil, false, nil + } + cachedVal := candidate.(*PlanCacheValue) + if err := CheckPreparedPriv(sctx, stmt, is); err != nil { + return nil, nil, false, err + } + for tblInfo, unionScan := range cachedVal.TblInfo2UnionScan { + if !unionScan && tableHasDirtyContent(sctx, tblInfo) { + // TODO we can inject UnionScan into cached plan to avoid invalidating it, though + // rebuilding the filters in UnionScan is pretty trivial. + sctx.GetSessionPlanCache().Delete(cacheKey) + return nil, nil, false, nil + } + } + if !RebuildPlan4CachedPlan(cachedVal.Plan) { + return nil, nil, false, nil + } + sessVars.FoundInPlanCache = true + if len(bindSQL) > 0 { + // When the `len(bindSQL) > 0`, it means we use the binding. + // So we need to record this. + sessVars.FoundInBinding = true + } + if metrics.ResettablePlanCacheCounterFortTest { + metrics.PlanCacheCounter.WithLabelValues("prepare").Inc() + } else { + core_metrics.GetPlanCacheHitCounter(isNonPrepared).Inc() + } + stmtCtx.SetPlanDigest(stmt.NormalizedPlan, stmt.PlanDigest) + stmtCtx.StmtHints = *cachedVal.stmtHints + return cachedVal.Plan, cachedVal.OutPutNames, true, nil +} + +// generateNewPlan call the optimizer to generate a new plan for current statement +// and try to add it to cache +func generateNewPlan(ctx context.Context, sctx sessionctx.Context, isNonPrepared bool, is infoschema.InfoSchema, + stmt *PlanCacheStmt, cacheKey kvcache.Key, latestSchemaVersion int64, bindSQL string, + matchOpts *utilpc.PlanCacheMatchOpts) (Plan, []*types.FieldName, error) { + stmtAst := stmt.PreparedAst + sessVars := sctx.GetSessionVars() + stmtCtx := sessVars.StmtCtx + + core_metrics.GetPlanCacheMissCounter(isNonPrepared).Inc() + sctx.GetSessionVars().StmtCtx.InPreparedPlanBuilding = true + p, names, err := OptimizeAstNode(ctx, sctx, stmtAst.Stmt, is) + sctx.GetSessionVars().StmtCtx.InPreparedPlanBuilding = false + if err != nil { + return nil, nil, err + } + err = tryCachePointPlan(ctx, sctx, stmt, is, p) + if err != nil { + return nil, nil, err + } + + // check whether this plan is cacheable. + if stmtCtx.UseCache { + if cacheable, reason := isPlanCacheable(sctx, p, len(matchOpts.ParamTypes), len(matchOpts.LimitOffsetAndCount), matchOpts.HasSubQuery); !cacheable { + stmtCtx.SetSkipPlanCache(errors.Errorf(reason)) + } + } + + // put this plan into the plan cache. 
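+	// A non-read-only statement cannot run on TiFlash, so TiFlash is removed from
+	// the isolation engines while the key is rebuilt below, then restored.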
+	if stmtCtx.UseCache {
+		// rebuild the key to exclude kv.TiFlash when the statement is not read-only
+		if _, isolationReadContainTiFlash := sessVars.IsolationReadEngines[kv.TiFlash]; isolationReadContainTiFlash && !IsReadOnly(stmtAst.Stmt, sessVars) {
+			delete(sessVars.IsolationReadEngines, kv.TiFlash)
+			if cacheKey, err = NewPlanCacheKey(sessVars, stmt.StmtText, stmt.StmtDB,
+				stmtAst.SchemaVersion, latestSchemaVersion, bindSQL, expression.ExprPushDownBlackListReloadTimeStamp.Load()); err != nil {
+				return nil, nil, err
+			}
+			sessVars.IsolationReadEngines[kv.TiFlash] = struct{}{}
+		}
+		cached := NewPlanCacheValue(p, names, stmtCtx.TblInfo2UnionScan, matchOpts, &stmtCtx.StmtHints)
+		stmt.NormalizedPlan, stmt.PlanDigest = NormalizePlan(p)
+		stmtCtx.SetPlan(p)
+		stmtCtx.SetPlanDigest(stmt.NormalizedPlan, stmt.PlanDigest)
+		sctx.GetSessionPlanCache().Put(cacheKey, cached, matchOpts)
+	}
+	sessVars.FoundInPlanCache = false
+	return p, names, err
+}
+
+// RebuildPlan4CachedPlan will rebuild this plan under the current user parameters.
+func RebuildPlan4CachedPlan(p Plan) (ok bool) {
+	sc := p.SCtx().GetSessionVars().StmtCtx
+	if !sc.UseCache {
+		return false // plan-cache is disabled for this query
+	}
+
+	sc.InPreparedPlanBuilding = true
+	defer func() { sc.InPreparedPlanBuilding = false }()
+	if err := rebuildRange(p); err != nil {
+		// TODO: log or warn this error.
+		return false // fail to rebuild ranges
+	}
+	if !sc.UseCache {
+		// In this case, the UseCache flag changed from `true` to `false`, so some
+		// over-optimized operation must have been triggered; return `false` for safety here.
+		return false
+	}
+	return true
+}
+
+func updateRange(p PhysicalPlan, ranges ranger.Ranges, rangeInfo string) {
+	switch x := p.(type) {
+	case *PhysicalTableScan:
+		x.Ranges = ranges
+		x.rangeInfo = rangeInfo
+	case *PhysicalIndexScan:
+		x.Ranges = ranges
+		x.rangeInfo = rangeInfo
+	case *PhysicalTableReader:
+		updateRange(x.TablePlans[0], ranges, rangeInfo)
+	case *PhysicalIndexReader:
+		updateRange(x.IndexPlans[0], ranges, rangeInfo)
+	case *PhysicalIndexLookUpReader:
+		updateRange(x.IndexPlans[0], ranges, rangeInfo)
+	}
+}
+
+// rebuildRange doesn't set a mem limit for building ranges. There are two reasons why we don't restrict range mem usage here.
+// 1. The cached plan must be able to build complete ranges under the mem limit when it is generated. Hence we can just build
+// ranges from x.AccessConditions. The only difference between the last ranges and the new ranges is the change of parameter
+// values, which doesn't cause much change on the mem usage of complete ranges.
+// 2. Different parameter values can change the mem usage of complete ranges. If we set a range mem limit here, range fallback
+// may happen and cause correctness problems. For example, `a in (?, ?, ?)` is the access condition. When the plan is first
+// generated, its complete ranges are ['a','a'], ['b','b'], ['c','c'], whose mem usage is under the range mem limit of 100B.
+// When the cached plan is hit, the complete ranges may become ['aaa','aaa'], ['bbb','bbb'], ['ccc','ccc'], whose mem
+// usage exceeds the range mem limit of 100B, so range fallback happens and TiDB may fetch more rows than users expect.
+func rebuildRange(p Plan) error {
+	sctx := p.SCtx()
+	sc := p.SCtx().GetSessionVars().StmtCtx
+	var err error
+	switch x := p.(type) {
+	case *PhysicalIndexHashJoin:
+		return rebuildRange(&x.PhysicalIndexJoin)
+	case *PhysicalIndexMergeJoin:
+		return rebuildRange(&x.PhysicalIndexJoin)
+	case *PhysicalIndexJoin:
+		if err := x.Ranges.Rebuild(); err != nil {
+			return err
+		}
+		if mutableRange, ok := x.Ranges.(*mutableIndexJoinRange); ok {
+			helper := mutableRange.buildHelper
+			rangeInfo := helper.buildRangeDecidedByInformation(helper.chosenPath.IdxCols, mutableRange.outerJoinKeys)
+			innerPlan := x.Children()[x.InnerChildIdx]
+			updateRange(innerPlan, x.Ranges.Range(), rangeInfo)
+		}
+		for _, child := range x.Children() {
+			err = rebuildRange(child)
+			if err != nil {
+				return err
+			}
+		}
+	case *PhysicalTableScan:
+		err = buildRangeForTableScan(sctx, x)
+		if err != nil {
+			return err
+		}
+	case *PhysicalIndexScan:
+		err = buildRangeForIndexScan(sctx, x)
+		if err != nil {
+			return err
+		}
+	case *PhysicalTableReader:
+		err = rebuildRange(x.TablePlans[0])
+		if err != nil {
+			return err
+		}
+	case *PhysicalIndexReader:
+		err = rebuildRange(x.IndexPlans[0])
+		if err != nil {
+			return err
+		}
+	case *PhysicalIndexLookUpReader:
+		err = rebuildRange(x.IndexPlans[0])
+		if err != nil {
+			return err
+		}
+	case *PointGetPlan:
+		// If the access conditions are not nil, this is a point get generated by the CBO.
+		if x.AccessConditions != nil {
+			if x.IndexInfo != nil {
+				ranges, err := ranger.DetachCondAndBuildRangeForIndex(x.ctx, x.AccessConditions, x.IdxCols, x.IdxColLens, 0)
+				if err != nil {
+					return err
+				}
+				if len(ranges.Ranges) != 1 || !isSafeRange(x.AccessConditions, ranges, false, nil) {
+					return errors.New("rebuild to get an unsafe range")
+				}
+				for i := range x.IndexValues {
+					x.IndexValues[i] = ranges.Ranges[0].LowVal[i]
+				}
+			} else {
+				var pkCol *expression.Column
+				var unsignedIntHandle bool
+				if x.TblInfo.PKIsHandle {
+					if pkColInfo := x.TblInfo.GetPkColInfo(); pkColInfo != nil {
+						pkCol = expression.ColInfo2Col(x.schema.Columns, pkColInfo)
+					}
+					if !x.TblInfo.IsCommonHandle {
+						unsignedIntHandle = true
+					}
+				}
+				if pkCol != nil {
+					ranges, accessConds, remainingConds, err := ranger.BuildTableRange(x.AccessConditions, x.ctx, pkCol.RetType, 0)
+					if err != nil {
+						return err
+					}
+					if len(ranges) != 1 || !isSafeRange(x.AccessConditions, &ranger.DetachRangeResult{
+						Ranges:        ranges,
+						AccessConds:   accessConds,
+						RemainedConds: remainingConds,
+					}, unsignedIntHandle, nil) {
+						return errors.New("rebuild to get an unsafe range")
+					}
+					x.Handle = kv.IntHandle(ranges[0].LowVal[0].GetInt64())
+				}
+			}
+		}
+		// The code should never reach here as long as we don't use point get for partition tables.
+		// And if we change the logic one day, this works as defensive programming to catch the error.
+ if x.PartitionInfo != nil { + // TODO: relocate the partition after rebuilding range to make PlanCache support PointGet + return errors.New("point get for partition table can not use plan cache") + } + if x.HandleConstant != nil { + dVal, err := convertConstant2Datum(sc, x.HandleConstant, x.handleFieldType) + if err != nil { + return err + } + iv, err := dVal.ToInt64(sc) + if err != nil { + return err + } + x.Handle = kv.IntHandle(iv) + return nil + } + for i, param := range x.IndexConstants { + if param != nil { + dVal, err := convertConstant2Datum(sc, param, x.ColsFieldType[i]) + if err != nil { + return err + } + x.IndexValues[i] = *dVal + } + } + return nil + case *BatchPointGetPlan: + // if access condition is not nil, which means it's a point get generated by cbo. + if x.AccessConditions != nil { + if x.IndexInfo != nil { + ranges, err := ranger.DetachCondAndBuildRangeForIndex(x.ctx, x.AccessConditions, x.IdxCols, x.IdxColLens, 0) + if err != nil { + return err + } + if len(ranges.Ranges) != len(x.IndexValues) || !isSafeRange(x.AccessConditions, ranges, false, nil) { + return errors.New("rebuild to get an unsafe range") + } + for i := range x.IndexValues { + copy(x.IndexValues[i], ranges.Ranges[i].LowVal) + } + } else { + var pkCol *expression.Column + var unsignedIntHandle bool + if x.TblInfo.PKIsHandle { + if pkColInfo := x.TblInfo.GetPkColInfo(); pkColInfo != nil { + pkCol = expression.ColInfo2Col(x.schema.Columns, pkColInfo) + } + if !x.TblInfo.IsCommonHandle { + unsignedIntHandle = true + } + } + if pkCol != nil { + ranges, accessConds, remainingConds, err := ranger.BuildTableRange(x.AccessConditions, x.ctx, pkCol.RetType, 0) + if err != nil { + return err + } + if len(ranges) != len(x.Handles) || !isSafeRange(x.AccessConditions, &ranger.DetachRangeResult{ + Ranges: ranges, + AccessConds: accessConds, + RemainedConds: remainingConds, + }, unsignedIntHandle, nil) { + return errors.New("rebuild to get an unsafe range") + } + for i := range ranges { + x.Handles[i] = kv.IntHandle(ranges[i].LowVal[0].GetInt64()) + } + } + } + } + for i, param := range x.HandleParams { + if param != nil { + dVal, err := convertConstant2Datum(sc, param, x.HandleType) + if err != nil { + return err + } + iv, err := dVal.ToInt64(sc) + if err != nil { + return err + } + x.Handles[i] = kv.IntHandle(iv) + } + } + for i, params := range x.IndexValueParams { + if len(params) < 1 { + continue + } + for j, param := range params { + if param != nil { + dVal, err := convertConstant2Datum(sc, param, x.IndexColTypes[j]) + if err != nil { + return err + } + x.IndexValues[i][j] = *dVal + } + } + } + case *PhysicalIndexMergeReader: + indexMerge := p.(*PhysicalIndexMergeReader) + for _, partialPlans := range indexMerge.PartialPlans { + err = rebuildRange(partialPlans[0]) + if err != nil { + return err + } + } + // We don't need to handle the indexMerge.TablePlans, because the tablePlans + // only can be (Selection) + TableRowIDScan. There have no range need to rebuild. 
+	case PhysicalPlan:
+		for _, child := range x.Children() {
+			err = rebuildRange(child)
+			if err != nil {
+				return err
+			}
+		}
+	case *Insert:
+		if x.SelectPlan != nil {
+			return rebuildRange(x.SelectPlan)
+		}
+	case *Update:
+		if x.SelectPlan != nil {
+			return rebuildRange(x.SelectPlan)
+		}
+	case *Delete:
+		if x.SelectPlan != nil {
+			return rebuildRange(x.SelectPlan)
+		}
+	}
+	return nil
+}
+
+func convertConstant2Datum(sc *stmtctx.StatementContext, con *expression.Constant, target *types.FieldType) (*types.Datum, error) {
+	val, err := con.Eval(chunk.Row{})
+	if err != nil {
+		return nil, err
+	}
+	dVal, err := val.ConvertTo(sc, target)
+	if err != nil {
+		return nil, err
+	}
+	// The converted result must be the same as the original datum.
+	cmp, err := dVal.Compare(sc, &val, collate.GetCollator(target.GetCollate()))
+	if err != nil || cmp != 0 {
+		return nil, errors.New("failed to convert the constant to datum, because the constant has changed after the conversion")
+	}
+	return &dVal, nil
+}
+
+func buildRangeForTableScan(sctx sessionctx.Context, ts *PhysicalTableScan) (err error) {
+	if ts.Table.IsCommonHandle {
+		pk := tables.FindPrimaryIndex(ts.Table)
+		pkCols := make([]*expression.Column, 0, len(pk.Columns))
+		pkColsLen := make([]int, 0, len(pk.Columns))
+		for _, colInfo := range pk.Columns {
+			if pkCol := expression.ColInfo2Col(ts.schema.Columns, ts.Table.Columns[colInfo.Offset]); pkCol != nil {
+				pkCols = append(pkCols, pkCol)
+				// We need to consider the prefix index.
+				// For example: when we have 'a varchar(50), index idx(a(10))',
+				// we will get 'colInfo.Length = 50' and 'pkCol.RetType.flen = 10'.
+				// In the 'hasPrefix' function in 'util/ranger/ranger.go',
+				// we use 'columnLength == types.UnspecifiedLength' to check whether we have a prefix index.
+				if colInfo.Length != types.UnspecifiedLength && colInfo.Length == pkCol.RetType.GetFlen() {
+					pkColsLen = append(pkColsLen, types.UnspecifiedLength)
+				} else {
+					pkColsLen = append(pkColsLen, colInfo.Length)
+				}
+			}
+		}
+		if len(pkCols) > 0 {
+			res, err := ranger.DetachCondAndBuildRangeForIndex(sctx, ts.AccessCondition, pkCols, pkColsLen, 0)
+			if err != nil {
+				return err
+			}
+			if !isSafeRange(ts.AccessCondition, res, false, ts.Ranges) {
+				return errors.New("rebuild to get an unsafe range")
+			}
+			ts.Ranges = res.Ranges
+		} else {
+			if len(ts.AccessCondition) > 0 {
+				return errors.New("fail to build ranges, cannot get the primary key column")
+			}
+			ts.Ranges = ranger.FullRange()
+		}
+	} else {
+		var pkCol *expression.Column
+		if ts.Table.PKIsHandle {
+			if pkColInfo := ts.Table.GetPkColInfo(); pkColInfo != nil {
+				pkCol = expression.ColInfo2Col(ts.schema.Columns, pkColInfo)
+			}
+		}
+		if pkCol != nil {
+			ranges, accessConds, remainingConds, err := ranger.BuildTableRange(ts.AccessCondition, sctx, pkCol.RetType, 0)
+			if err != nil {
+				return err
+			}
+			if !isSafeRange(ts.AccessCondition, &ranger.DetachRangeResult{
+				Ranges:        ranges,
+				AccessConds:   accessConds,
+				RemainedConds: remainingConds,
+			}, true, ts.Ranges) {
+				return errors.New("rebuild to get an unsafe range")
+			}
+			ts.Ranges = ranges
+		} else {
+			if len(ts.AccessCondition) > 0 {
+				return errors.New("fail to build ranges, cannot get the primary key column")
+			}
+			ts.Ranges = ranger.FullIntRange(false)
+		}
+	}
+	return
+}
+
+func buildRangeForIndexScan(sctx sessionctx.Context, is *PhysicalIndexScan) (err error) {
+	if len(is.IdxCols) == 0 {
+		if ranger.HasFullRange(is.Ranges, false) { // the original range is already a full-range.
+ is.Ranges = ranger.FullRange() + return + } + return errors.New("unexpected range for PhysicalIndexScan") + } + + res, err := ranger.DetachCondAndBuildRangeForIndex(sctx, is.AccessCondition, is.IdxCols, is.IdxColLens, 0) + if err != nil { + return err + } + if !isSafeRange(is.AccessCondition, res, false, is.Ranges) { + return errors.New("rebuild to get an unsafe range") + } + is.Ranges = res.Ranges + return +} + +// checkRebuiltRange checks whether the re-built range is safe. +// To re-use a cached plan, the planner needs to rebuild the access range, but as +// parameters change, some unsafe ranges may occur. +// For example, the first time the planner can build a range `(2, 5)` from `a>2 and a<(?)5`, but if the +// parameter changes to `(?)1`, then it'll get an unsafe range `(empty)`. +// To make plan-cache safer, let the planner abandon the cached plan if it gets an unsafe range here. +func isSafeRange(accessConds []expression.Expression, rebuiltResult *ranger.DetachRangeResult, + unsignedIntHandle bool, originalRange ranger.Ranges) (safe bool) { + if len(rebuiltResult.RemainedConds) > 0 || // the ranger generates some other extra conditions + len(rebuiltResult.AccessConds) != len(accessConds) || // not all access conditions are used + len(rebuiltResult.Ranges) == 0 { // get an empty range + return false + } + + if len(accessConds) > 0 && // if have accessConds, and + ranger.HasFullRange(rebuiltResult.Ranges, unsignedIntHandle) && // get an full range, and + originalRange != nil && !ranger.HasFullRange(originalRange, unsignedIntHandle) { // the original range is not a full range + return false + } + + return true +} + +// CheckPreparedPriv checks the privilege of the prepared statement +func CheckPreparedPriv(sctx sessionctx.Context, stmt *PlanCacheStmt, is infoschema.InfoSchema) error { + if pm := privilege.GetPrivilegeManager(sctx); pm != nil { + visitInfo := VisitInfo4PrivCheck(is, stmt.PreparedAst.Stmt, stmt.VisitInfos) + if err := CheckPrivilege(sctx.GetSessionVars().ActiveRoles, pm, visitInfo); err != nil { + return err + } + } + err := CheckTableLock(sctx, is, stmt.VisitInfos) + return err +} + +// tryCachePointPlan will try to cache point execution plan, there may be some +// short paths for these executions, currently "point select" and "point update" +func tryCachePointPlan(_ context.Context, sctx sessionctx.Context, + stmt *PlanCacheStmt, _ infoschema.InfoSchema, p Plan) error { + if !sctx.GetSessionVars().StmtCtx.UseCache { + return nil + } + var ( + stmtAst = stmt.PreparedAst + ok bool + err error + names types.NameSlice + ) + + if plan, _ok := p.(*PointGetPlan); _ok { + ok, err = IsPointGetWithPKOrUniqueKeyByAutoCommit(sctx, p) + names = p.OutputNames() + if err != nil { + return err + } + if ok { + plan.stmtHints = sctx.GetSessionVars().StmtCtx.StmtHints.Clone() + } + } + + if ok { + // just cache point plan now + stmtAst.CachedPlan = p + stmtAst.CachedNames = names + stmt.NormalizedPlan, stmt.PlanDigest = NormalizePlan(p) + sctx.GetSessionVars().StmtCtx.SetPlan(p) + sctx.GetSessionVars().StmtCtx.SetPlanDigest(stmt.NormalizedPlan, stmt.PlanDigest) + } + return err +} + +// GetBindSQL4PlanCache used to get the bindSQL for plan cache to build the plan cache key. 
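+// It returns the matched binding's BindSQL together with a flag telling the caller
+// to skip the plan cache; session-level bindings are consulted before global ones.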
+func GetBindSQL4PlanCache(sctx sessionctx.Context, stmt *PlanCacheStmt) (string, bool) { + useBinding := sctx.GetSessionVars().UsePlanBaselines + ignore := false + if !useBinding || stmt.PreparedAst.Stmt == nil || stmt.NormalizedSQL4PC == "" || stmt.SQLDigest4PC == "" { + return "", ignore + } + if sctx.Value(bindinfo.SessionBindInfoKeyType) == nil { + return "", ignore + } + sessionHandle := sctx.Value(bindinfo.SessionBindInfoKeyType).(*bindinfo.SessionHandle) + bindRecord := sessionHandle.GetBindRecord(stmt.SQLDigest4PC, stmt.NormalizedSQL4PC, "") + if bindRecord != nil { + enabledBinding := bindRecord.FindEnabledBinding() + if enabledBinding != nil { + ignore = enabledBinding.Hint.ContainTableHint(HintIgnorePlanCache) + return enabledBinding.BindSQL, ignore + } + } + globalHandle := domain.GetDomain(sctx).BindHandle() + if globalHandle == nil { + return "", ignore + } + bindRecord = globalHandle.GetBindRecord(stmt.SQLDigest4PC, stmt.NormalizedSQL4PC, "") + if bindRecord != nil { + enabledBinding := bindRecord.FindEnabledBinding() + if enabledBinding != nil { + ignore = enabledBinding.Hint.ContainTableHint(HintIgnorePlanCache) + return enabledBinding.BindSQL, ignore + } + } + return "", ignore +} + +// IsPointPlanShortPathOK check if we can execute using plan cached in prepared structure +// Be careful with the short path, current precondition is ths cached plan satisfying +// IsPointGetWithPKOrUniqueKeyByAutoCommit +func IsPointPlanShortPathOK(sctx sessionctx.Context, is infoschema.InfoSchema, stmt *PlanCacheStmt) (bool, error) { + stmtAst := stmt.PreparedAst + if stmtAst.CachedPlan == nil || staleread.IsStmtStaleness(sctx) { + return false, nil + } + // check auto commit + if !IsAutoCommitTxn(sctx) { + return false, nil + } + if stmtAst.SchemaVersion != is.SchemaMetaVersion() { + stmtAst.CachedPlan = nil + stmt.ColumnInfos = nil + return false, nil + } + // maybe we'd better check cached plan type here, current + // only point select/update will be cached, see "getPhysicalPlan" func + var ok bool + var err error + switch stmtAst.CachedPlan.(type) { + case *PointGetPlan: + ok = true + case *Update: + pointUpdate := stmtAst.CachedPlan.(*Update) + _, ok = pointUpdate.SelectPlan.(*PointGetPlan) + if !ok { + err = errors.Errorf("cached update plan not point update") + stmtAst.CachedPlan = nil + return false, err + } + default: + ok = false + } + return ok, err +} diff --git a/planner/core/plan_cache_utils.go b/planner/core/plan_cache_utils.go new file mode 100644 index 0000000000000..e5e83befbcd77 --- /dev/null +++ b/planner/core/plan_cache_utils.go @@ -0,0 +1,665 @@ +// Copyright 2017 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+
+package core
+
+import (
+	"cmp"
+	"context"
+	"math"
+	"slices"
+	"strconv"
+	"time"
+	"unsafe"
+
+	"github.com/pingcap/errors"
+	"github.com/pingcap/tidb/expression"
+	"github.com/pingcap/tidb/infoschema"
+	"github.com/pingcap/tidb/kv"
+	"github.com/pingcap/tidb/parser"
+	"github.com/pingcap/tidb/parser/ast"
+	"github.com/pingcap/tidb/parser/model"
+	"github.com/pingcap/tidb/parser/mysql"
+	"github.com/pingcap/tidb/planner/util"
+	"github.com/pingcap/tidb/planner/util/fixcontrol"
+	"github.com/pingcap/tidb/sessionctx"
+	"github.com/pingcap/tidb/sessionctx/stmtctx"
+	"github.com/pingcap/tidb/sessionctx/variable"
+	"github.com/pingcap/tidb/statistics"
+	"github.com/pingcap/tidb/types"
+	driver "github.com/pingcap/tidb/types/parser_driver"
+	"github.com/pingcap/tidb/util/codec"
+	"github.com/pingcap/tidb/util/hack"
+	"github.com/pingcap/tidb/util/hint"
+	"github.com/pingcap/tidb/util/intest"
+	"github.com/pingcap/tidb/util/kvcache"
+	utilpc "github.com/pingcap/tidb/util/plancache"
+	"github.com/pingcap/tidb/util/size"
+	atomic2 "go.uber.org/atomic"
+)
+
+const (
+	// MaxCacheableLimitCount is the max limit count for a cacheable query.
+	MaxCacheableLimitCount = 10000
+)
+
+var (
+	// PreparedPlanCacheMaxMemory stores the max memory size defined in the global config "performance-server-memory-quota".
+	PreparedPlanCacheMaxMemory = *atomic2.NewUint64(math.MaxUint64)
+
+	// ExtractSelectAndNormalizeDigest extracts the select statement and normalizes it.
+	ExtractSelectAndNormalizeDigest func(stmtNode ast.StmtNode, specifiledDB string) (ast.StmtNode, string, string, error)
+)
+
+type paramMarkerExtractor struct {
+	markers []ast.ParamMarkerExpr
+}
+
+func (*paramMarkerExtractor) Enter(in ast.Node) (ast.Node, bool) {
+	return in, false
+}
+
+func (e *paramMarkerExtractor) Leave(in ast.Node) (ast.Node, bool) {
+	if x, ok := in.(*driver.ParamMarkerExpr); ok {
+		e.markers = append(e.markers, x)
+	}
+	return in, true
+}
+
+// GeneratePlanCacheStmtWithAST generates the PlanCacheStmt structure for this AST.
+// paramSQL is the corresponding parameterized sql like 'select * from t where a<?'.
+// paramStmt is the Node of paramSQL.
+func GeneratePlanCacheStmtWithAST(ctx context.Context, sctx sessionctx.Context, isPrepStmt bool,
+	paramSQL string, paramStmt ast.StmtNode, is infoschema.InfoSchema) (*PlanCacheStmt, Plan, int, error) {
+	vars := sctx.GetSessionVars()
+	var extractor paramMarkerExtractor
+	paramStmt.Accept(&extractor)
+
+	// DDL statements cannot accept parameters
+	if _, ok := paramStmt.(ast.DDLNode); ok && len(extractor.markers) > 0 {
+		return nil, nil, 0, ErrPrepareDDL
+	}
+
+	switch paramStmt.(type) {
+	case *ast.ImportIntoStmt, *ast.LoadDataStmt, *ast.PrepareStmt, *ast.ExecuteStmt, *ast.DeallocateStmt, *ast.NonTransactionalDMLStmt:
+		return nil, nil, 0, ErrUnsupportedPs
+	}
+
+	// The number of prepared parameters must NOT exceed 2 bytes (MaxUint16):
+	// https://dev.mysql.com/doc/internals/en/com-stmt-prepare-response.html#packet-COM_STMT_PREPARE_OK.
+	if len(extractor.markers) > math.MaxUint16 {
+		return nil, nil, 0, ErrPsManyParam
+	}
+
+	ret := &PreprocessorReturn{InfoSchema: is} // `is` can be nil
+	err := Preprocess(ctx, sctx, paramStmt, InPrepare, WithPreprocessorReturn(ret))
+	if err != nil {
+		return nil, nil, 0, err
+	}
+
+	// The parameter markers are appended in visiting order, which may not
+	// be the same as the position order in the query string. We need to
+	// sort them by position.
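+	// e.g. for `select * from t where a > ? and b < ?`, the first marker in the
+	// statement text gets order 0 and the second order 1, regardless of visiting order.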
+	slices.SortFunc(extractor.markers, func(i, j ast.ParamMarkerExpr) int {
+		return cmp.Compare(i.(*driver.ParamMarkerExpr).Offset, j.(*driver.ParamMarkerExpr).Offset)
+	})
+	paramCount := len(extractor.markers)
+	for i := 0; i < paramCount; i++ {
+		extractor.markers[i].SetOrder(i)
+	}
+
+	prepared := &ast.Prepared{
+		Stmt:          paramStmt,
+		StmtType:      ast.GetStmtLabel(paramStmt),
+		Params:        extractor.markers,
+		SchemaVersion: ret.InfoSchema.SchemaMetaVersion(),
+	}
+	normalizedSQL, digest := parser.NormalizeDigest(prepared.Stmt.Text())
+
+	var (
+		normalizedSQL4PC, digest4PC string
+		selectStmtNode              ast.StmtNode
+		cacheable                   bool
+		reason                      string
+	)
+	if (isPrepStmt && !vars.EnablePreparedPlanCache) || // prepared statement
+		(!isPrepStmt && !vars.EnableNonPreparedPlanCache) { // non-prepared statement
+		cacheable = false
+		reason = "plan cache is disabled"
+	} else {
+		if isPrepStmt {
+			cacheable, reason = CacheableWithCtx(sctx, paramStmt, ret.InfoSchema)
+		} else {
+			cacheable = true // it is already checked here
+		}
+		if !cacheable {
+			sctx.GetSessionVars().StmtCtx.AppendWarning(errors.Errorf("skip prepared plan-cache: " + reason))
+		}
+		selectStmtNode, normalizedSQL4PC, digest4PC, err = ExtractSelectAndNormalizeDigest(paramStmt, vars.CurrentDB)
+		if err != nil || selectStmtNode == nil {
+			normalizedSQL4PC = ""
+			digest4PC = ""
+		}
+	}
+
+	// For prepared statements like `prepare st from 'select * from t where a<?'`,
+	// the parameter values are not known until EXECUTE binds them.
+
+// MemoryUsage returns the memory usage of planCacheKey
+func (key *planCacheKey) MemoryUsage() (sum int64) {
+	if key.memoryUsage > 0 {
+		return key.memoryUsage
+	}
+	sum = emptyPlanCacheKeySize + int64(len(key.database)+len(key.stmtText)+len(key.bindSQL)) +
+		int64(len(key.isolationReadEngines))*size.SizeOfUint8 + int64(cap(key.hash))
+	key.memoryUsage = sum
+	return
+}
+
+// SetPstmtIDSchemaVersion implements the PstmtCacheKeyMutator interface to change pstmtID and schemaVersion of a cacheKey,
+// so we can reuse the Key instead of allocating a new one every time.
+func SetPstmtIDSchemaVersion(key kvcache.Key, stmtText string, schemaVersion int64, isolationReadEngines map[kv.StoreType]struct{}) {
+	psStmtKey, isPsStmtKey := key.(*planCacheKey)
+	if !isPsStmtKey {
+		return
+	}
+	psStmtKey.stmtText = stmtText
+	psStmtKey.schemaVersion = schemaVersion
+	psStmtKey.isolationReadEngines = make(map[kv.StoreType]struct{})
+	for k, v := range isolationReadEngines {
+		psStmtKey.isolationReadEngines[k] = v
+	}
+	psStmtKey.hash = psStmtKey.hash[:0]
+}
+
+// NewPlanCacheKey creates a new planCacheKey object.
+// Note: lastUpdatedSchemaVersion will only be set in the case of rc or for update read in order to
+// differentiate the cache key. In other cases, it will be 0.
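+// Besides the statement text and schema versions, the key captures the session
+// settings that affect plan validity, such as the current database, connection ID,
+// SQL mode, time zone offset, isolation read engines, sql_select_limit, read-only
+// flags and the expression push-down blacklist reload timestamp.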
+func NewPlanCacheKey(sessionVars *variable.SessionVars, stmtText, stmtDB string, schemaVersion int64,
+	lastUpdatedSchemaVersion int64, bindSQL string, exprBlacklistTS int64) (kvcache.Key, error) {
+	if stmtText == "" {
+		return nil, errors.New("no statement text")
+	}
+	if schemaVersion == 0 && !intest.InTest {
+		return nil, errors.New("Schema version uninitialized")
+	}
+	if stmtDB == "" {
+		stmtDB = sessionVars.CurrentDB
+	}
+	timezoneOffset := 0
+	if sessionVars.TimeZone != nil {
+		_, timezoneOffset = time.Now().In(sessionVars.TimeZone).Zone()
+	}
+	key := &planCacheKey{
+		database:                 stmtDB,
+		connID:                   sessionVars.ConnectionID,
+		stmtText:                 stmtText,
+		schemaVersion:            schemaVersion,
+		lastUpdatedSchemaVersion: lastUpdatedSchemaVersion,
+		sqlMode:                  sessionVars.SQLMode,
+		timezoneOffset:           timezoneOffset,
+		isolationReadEngines:     make(map[kv.StoreType]struct{}),
+		selectLimit:              sessionVars.SelectLimit,
+		bindSQL:                  bindSQL,
+		inRestrictedSQL:          sessionVars.InRestrictedSQL,
+		restrictedReadOnly:       variable.RestrictedReadOnly.Load(),
+		TiDBSuperReadOnly:        variable.VarTiDBSuperReadOnly.Load(),
+		ExprBlacklistTS:          exprBlacklistTS,
+	}
+	for k, v := range sessionVars.IsolationReadEngines {
+		key.isolationReadEngines[k] = v
+	}
+	return key, nil
+}
+
+// PlanCacheValue stores the cached Statement and StmtNode.
+type PlanCacheValue struct {
+	Plan              Plan
+	OutPutNames       []*types.FieldName
+	TblInfo2UnionScan map[*model.TableInfo]bool
+	memoryUsage       int64
+
+	// matchOpts stores some fields that help to choose a suitable plan
+	matchOpts *utilpc.PlanCacheMatchOpts
+	// stmtHints stores the hints which set session variables, because the hints won't be processed when using the cached plan.
+	stmtHints *stmtctx.StmtHints
+}
+
+// unKnownMemoryUsage represents the memory usage of an uncounted structure; it may need to be implemented later.
+// 50 KiB is the approximate consumption of a plan from our internal tests.
+const unKnownMemoryUsage = int64(50 * size.KB)
+
+// MemoryUsage returns the memory usage of PlanCacheValue
+func (v *PlanCacheValue) MemoryUsage() (sum int64) {
+	if v == nil {
+		return
+	}
+
+	if v.memoryUsage > 0 {
+		return v.memoryUsage
+	}
+	switch x := v.Plan.(type) {
+	case PhysicalPlan:
+		sum = x.MemoryUsage()
+	case *Insert:
+		sum = x.MemoryUsage()
+	case *Update:
+		sum = x.MemoryUsage()
+	case *Delete:
+		sum = x.MemoryUsage()
+	default:
+		sum = unKnownMemoryUsage
+	}
+
+	sum += size.SizeOfInterface + size.SizeOfSlice*2 + int64(cap(v.OutPutNames))*size.SizeOfPointer +
+		size.SizeOfMap + int64(len(v.TblInfo2UnionScan))*(size.SizeOfPointer+size.SizeOfBool) + size.SizeOfInt64*2
+	if v.matchOpts != nil {
+		sum += int64(cap(v.matchOpts.ParamTypes)) * size.SizeOfPointer
+		for _, ft := range v.matchOpts.ParamTypes {
+			sum += ft.MemoryUsage()
+		}
+	}
+
+	for _, name := range v.OutPutNames {
+		sum += name.MemoryUsage()
+	}
+	v.memoryUsage = sum
+	return
+}
+
+// NewPlanCacheValue creates a PlanCacheValue.
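+// The TblInfo2UnionScan map and the statement hints are copied so that later
+// mutations by the session do not leak into the cached value.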
+func NewPlanCacheValue(plan Plan, names []*types.FieldName, srcMap map[*model.TableInfo]bool,
+	matchOpts *utilpc.PlanCacheMatchOpts, stmtHints *stmtctx.StmtHints) *PlanCacheValue {
+	dstMap := make(map[*model.TableInfo]bool)
+	for k, v := range srcMap {
+		dstMap[k] = v
+	}
+	userParamTypes := make([]*types.FieldType, len(matchOpts.ParamTypes))
+	for i, tp := range matchOpts.ParamTypes {
+		userParamTypes[i] = tp.Clone()
+	}
+	return &PlanCacheValue{
+		Plan:              plan,
+		OutPutNames:       names,
+		TblInfo2UnionScan: dstMap,
+		matchOpts:         matchOpts,
+		stmtHints:         stmtHints.Clone(),
+	}
+}
+
+// PlanCacheQueryFeatures records all query features which may affect plan selection.
+type PlanCacheQueryFeatures struct {
+	limits      []*ast.Limit
+	hasSubquery bool
+	tables      []*ast.TableName // to capture table stats changes
+}
+
+// Enter implements the Visitor interface.
+func (f *PlanCacheQueryFeatures) Enter(in ast.Node) (out ast.Node, skipChildren bool) {
+	switch node := in.(type) {
+	case *ast.Limit:
+		f.limits = append(f.limits, node)
+	case *ast.SubqueryExpr, *ast.ExistsSubqueryExpr:
+		f.hasSubquery = true
+	case *ast.TableName:
+		f.tables = append(f.tables, node)
+	}
+	return in, false
+}
+
+// Leave implements the Visitor interface.
+func (*PlanCacheQueryFeatures) Leave(in ast.Node) (out ast.Node, ok bool) {
+	return in, true
+}
+
+// PlanCacheStmt stores the prepared AST from PrepareExec and other related fields
+type PlanCacheStmt struct {
+	PreparedAst *ast.Prepared
+	StmtDB      string // which DB the statement will be processed over
+	VisitInfos  []visitInfo
+	ColumnInfos interface{}
+	// Executor is only used for the point get scene.
+	// Notice that we should only cache the PointGetExecutor that has a snapshot with MaxTS in it.
+	// If the current plan is not PointGet or does not use the MaxTS optimization, this value should be nil here.
+	Executor interface{}
+
+	StmtCacheable     bool   // Whether this stmt is cacheable.
+	UncacheableReason string // Why this stmt is uncacheable.
+	QueryFeatures     *PlanCacheQueryFeatures
+
+	NormalizedSQL       string
+	NormalizedPlan      string
+	SQLDigest           *parser.Digest
+	PlanDigest          *parser.Digest
+	ForUpdateRead       bool
+	SnapshotTSEvaluator func(sessionctx.Context) (uint64, error)
+	NormalizedSQL4PC    string
+	SQLDigest4PC        string
+
+	// the difference between NormalizedSQL, NormalizedSQL4PC and StmtText:
+	//  for the query `select * from t where a>1 and b<2`,
+	//  NormalizedSQL:    select * from `t` where `a` > ? and `b` < ? --> constants are normalized to '?',
+	//  NormalizedSQL4PC: select * from `test` . `t` where `a` > ? and `b` < ? --> schema name is added,
+	//  StmtText:         select * from t where a>1 and b<2 --> just the original query text;
+	StmtText string
+}
+
+// GetPreparedStmt extracts the prepared statement from the execute statement.
+func GetPreparedStmt(stmt *ast.ExecuteStmt, vars *variable.SessionVars) (*PlanCacheStmt, error) { + if stmt.PrepStmt != nil { + return stmt.PrepStmt.(*PlanCacheStmt), nil + } + if stmt.Name != "" { + prepStmt, err := vars.GetPreparedStmtByName(stmt.Name) + if err != nil { + return nil, err + } + stmt.PrepStmt = prepStmt + return prepStmt.(*PlanCacheStmt), nil + } + return nil, ErrStmtNotFound +} + +func tableStatsVersionForPlanCache(tStats *statistics.Table) (tableStatsVer uint64) { + if tStats == nil { + return 0 + } + // use the max version of all columns and indices as the table stats version + for _, col := range tStats.Columns { + if col.LastUpdateVersion > tableStatsVer { + tableStatsVer = col.LastUpdateVersion + } + } + for _, idx := range tStats.Indices { + if idx.LastUpdateVersion > tableStatsVer { + tableStatsVer = idx.LastUpdateVersion + } + } + return tableStatsVer +} + +// GetMatchOpts get options to fetch plan or generate new plan +// we can add more options here +func GetMatchOpts(sctx sessionctx.Context, is infoschema.InfoSchema, stmt *PlanCacheStmt, params []expression.Expression) (*utilpc.PlanCacheMatchOpts, error) { + var statsVerHash uint64 + var limitOffsetAndCount []uint64 + + if stmt.QueryFeatures != nil { + for _, node := range stmt.QueryFeatures.tables { + t, err := is.TableByName(node.Schema, node.Name) + if err != nil { // CTE in this case + continue + } + tStats := getStatsTable(sctx, t.Meta(), t.Meta().ID) + statsVerHash += tableStatsVersionForPlanCache(tStats) // use '+' as the hash function for simplicity + } + + for _, node := range stmt.QueryFeatures.limits { + if node.Count != nil { + if count, isParamMarker := node.Count.(*driver.ParamMarkerExpr); isParamMarker { + typeExpected, val := CheckParamTypeInt64orUint64(count) + if !typeExpected { + sctx.GetSessionVars().StmtCtx.SetSkipPlanCache(errors.New("unexpected value after LIMIT")) + break + } + if val > MaxCacheableLimitCount { + sctx.GetSessionVars().StmtCtx.SetSkipPlanCache(errors.New("limit count is too large")) + break + } + limitOffsetAndCount = append(limitOffsetAndCount, val) + } + } + if node.Offset != nil { + if offset, isParamMarker := node.Offset.(*driver.ParamMarkerExpr); isParamMarker { + typeExpected, val := CheckParamTypeInt64orUint64(offset) + if !typeExpected { + sctx.GetSessionVars().StmtCtx.SetSkipPlanCache(errors.New("unexpected value after LIMIT")) + break + } + limitOffsetAndCount = append(limitOffsetAndCount, val) + } + } + } + } + + return &utilpc.PlanCacheMatchOpts{ + LimitOffsetAndCount: limitOffsetAndCount, + HasSubQuery: stmt.QueryFeatures.hasSubquery, + StatsVersionHash: statsVerHash, + ParamTypes: parseParamTypes(sctx, params), + ForeignKeyChecks: sctx.GetSessionVars().ForeignKeyChecks, + }, nil +} + +// CheckTypesCompatibility4PC compares FieldSlice with []*types.FieldType +// Currently this is only used in plan cache to check whether the types of parameters are compatible. +// If the types of parameters are compatible, we can use the cached plan. 
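+// For example, a VARCHAR parameter still matches a VAR_STRING one, while a change
+// in the unsigned flag of an integer parameter invalidates the match.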
+// tpsExpected is types from cached plan +func checkTypesCompatibility4PC(tpsExpected, tpsActual []*types.FieldType) bool { + if len(tpsExpected) != len(tpsActual) { + return false + } + for i := range tpsActual { + // We only use part of logic of `func (ft *FieldType) Equal(other *FieldType)` here because (1) only numeric and + // string types will show up here, and (2) we don't need flen and decimal to be matched exactly to use plan cache + tpEqual := (tpsExpected[i].GetType() == tpsActual[i].GetType()) || + (tpsExpected[i].GetType() == mysql.TypeVarchar && tpsActual[i].GetType() == mysql.TypeVarString) || + (tpsExpected[i].GetType() == mysql.TypeVarString && tpsActual[i].GetType() == mysql.TypeVarchar) + if !tpEqual || tpsExpected[i].GetCharset() != tpsActual[i].GetCharset() || tpsExpected[i].GetCollate() != tpsActual[i].GetCollate() || + (tpsExpected[i].EvalType() == types.ETInt && mysql.HasUnsignedFlag(tpsExpected[i].GetFlag()) != mysql.HasUnsignedFlag(tpsActual[i].GetFlag())) { + return false + } + // When the type is decimal, we should compare the Flen and Decimal. + // We can only use the plan when both Flen and Decimal should less equal than the cached one. + // We assume here that there is no correctness problem when the precision of the parameters is less than the precision of the parameters in the cache. + if tpEqual && tpsExpected[i].GetType() == mysql.TypeNewDecimal && !(tpsExpected[i].GetFlen() >= tpsActual[i].GetFlen() && tpsExpected[i].GetDecimal() >= tpsActual[i].GetDecimal()) { + return false + } + } + return true +} + +func isSafePointGetPath4PlanCache(sctx sessionctx.Context, path *util.AccessPath) bool { + // PointGet might contain some over-optimized assumptions, like `a>=1 and a<=1` --> `a=1`, but + // these assumptions may be broken after parameters change. + + if isSafePointGetPath4PlanCacheScenario1(path) { + return true + } + + // TODO: enable this fix control switch by default after more test cases are added. + if sctx != nil && sctx.GetSessionVars() != nil && sctx.GetSessionVars().OptimizerFixControl != nil { + fixControlOK := fixcontrol.GetBoolWithDefault(sctx.GetSessionVars().GetOptimizerFixControlMap(), fixcontrol.Fix44830, false) + if fixControlOK && (isSafePointGetPath4PlanCacheScenario2(path) || isSafePointGetPath4PlanCacheScenario3(path)) { + return true + } + } + + return false +} + +func isSafePointGetPath4PlanCacheScenario1(path *util.AccessPath) bool { + // safe scenario 1: each column corresponds to a single EQ, `a=1 and b=2 and c=3` --> `[1, 2, 3]` + if len(path.Ranges) <= 0 || path.Ranges[0].Width() != len(path.AccessConds) { + return false + } + for _, accessCond := range path.AccessConds { + f, ok := accessCond.(*expression.ScalarFunction) + if !ok || f.FuncName.L != ast.EQ { // column = constant + return false + } + } + return true +} + +func isSafePointGetPath4PlanCacheScenario2(path *util.AccessPath) bool { + // safe scenario 2: this Batch or PointGet is simply from a single IN predicate, `key in (...)` + if len(path.Ranges) <= 0 || len(path.AccessConds) != 1 { + return false + } + f, ok := path.AccessConds[0].(*expression.ScalarFunction) + if !ok || f.FuncName.L != ast.In { + return false + } + return len(path.Ranges) == len(f.GetArgs())-1 // no duplicated values in this in-list for safety. +} + +func isSafePointGetPath4PlanCacheScenario3(path *util.AccessPath) bool { + // safe scenario 3: this Batch or PointGet is simply from a simple DNF like `key=? or key=? 
or key=?` + if len(path.Ranges) <= 0 || len(path.AccessConds) != 1 { + return false + } + f, ok := path.AccessConds[0].(*expression.ScalarFunction) + if !ok || f.FuncName.L != ast.LogicOr { + return false + } + + dnfExprs := expression.FlattenDNFConditions(f) + if len(path.Ranges) != len(dnfExprs) { + // no duplicated values in this in-list for safety. + // e.g. `k=1 or k=2 or k=1` --> [[1, 1], [2, 2]] + return false + } + + for _, expr := range dnfExprs { + f, ok := expr.(*expression.ScalarFunction) + if !ok { + return false + } + switch f.FuncName.L { + case ast.EQ: // (k=1 or k=2) --> [k=1, k=2] + case ast.LogicAnd: // ((k1=1 and k2=1) or (k1=2 and k2=2)) --> [k1=1 and k2=1, k2=2 and k2=2] + cnfExprs := expression.FlattenCNFConditions(f) + if path.Ranges[0].Width() != len(cnfExprs) { // not all key columns are specified + return false + } + for _, expr := range cnfExprs { // k1=1 and k2=1 + f, ok := expr.(*expression.ScalarFunction) + if !ok || f.FuncName.L != ast.EQ { + return false + } + } + default: + return false + } + } + return true +} diff --git a/planner/core/point_get_plan.go b/planner/core/point_get_plan.go index 5d309606afc06..1bc57c32db40e 100644 --- a/planner/core/point_get_plan.go +++ b/planner/core/point_get_plan.go @@ -87,6 +87,33 @@ type PointGetPlan struct { planCost float64 // accessCols represents actual columns the PointGet will access, which are used to calculate row-size accessCols []*expression.Column +<<<<<<< HEAD +======= + + // probeParents records the IndexJoins and Applys with this operator in their inner children. + // Please see comments in PhysicalPlan for details. + probeParents []PhysicalPlan + // stmtHints should restore in executing context. + stmtHints *stmtctx.StmtHints +} + +func (p *PointGetPlan) getEstRowCountForDisplay() float64 { + if p == nil { + return 0 + } + return p.StatsInfo().RowCount * getEstimatedProbeCntFromProbeParents(p.probeParents) +} + +func (p *PointGetPlan) getActualProbeCnt(statsColl *execdetails.RuntimeStatsColl) int64 { + if p == nil { + return 1 + } + return getActualProbeCntFromProbeParents(p.probeParents, statsColl) +} + +func (p *PointGetPlan) setProbeParents(probeParents []PhysicalPlan) { + p.probeParents = probeParents +>>>>>>> c34f6fc83d6 (planner: store the hints of session variable (#45814)) } type nameValuePair struct { diff --git a/session/test/vars/BUILD.bazel b/session/test/vars/BUILD.bazel new file mode 100644 index 0000000000000..a53c88a43ed62 --- /dev/null +++ b/session/test/vars/BUILD.bazel @@ -0,0 +1,30 @@ +load("@io_bazel_rules_go//go:def.bzl", "go_test") + +go_test( + name = "vars_test", + timeout = "short", + srcs = [ + "main_test.go", + "vars_test.go", + ], + flaky = True, + shard_count = 13, + deps = [ + "//config", + "//domain", + "//errno", + "//kv", + "//parser/mysql", + "//parser/terror", + "//sessionctx/stmtctx", + "//sessionctx/variable", + "//testkit", + "//testkit/testmain", + "//testkit/testsetup", + "@com_github_pingcap_failpoint//:failpoint", + "@com_github_stretchr_testify//require", + "@com_github_tikv_client_go_v2//tikv", + "@com_github_tikv_client_go_v2//txnkv/transaction", + "@org_uber_go_goleak//:goleak", + ], +) diff --git a/session/test/vars/vars_test.go b/session/test/vars/vars_test.go new file mode 100644 index 0000000000000..d533c175a880f --- /dev/null +++ b/session/test/vars/vars_test.go @@ -0,0 +1,685 @@ +// Copyright 2023 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package vars + +import ( + "context" + "fmt" + "strconv" + "testing" + "time" + + "github.com/pingcap/failpoint" + "github.com/pingcap/tidb/domain" + "github.com/pingcap/tidb/errno" + tikv "github.com/pingcap/tidb/kv" + "github.com/pingcap/tidb/parser/mysql" + "github.com/pingcap/tidb/parser/terror" + "github.com/pingcap/tidb/sessionctx/stmtctx" + "github.com/pingcap/tidb/sessionctx/variable" + "github.com/pingcap/tidb/testkit" + "github.com/stretchr/testify/require" + "github.com/tikv/client-go/v2/txnkv/transaction" +) + +func TestKVVars(t *testing.T) { + store := testkit.CreateMockStore(t) + tk := testkit.NewTestKit(t, store) + tk.MustExec("use test") + tk.MustExec("set @@tidb_backoff_lock_fast = 1") + tk.MustExec("set @@tidb_backoff_weight = 100") + tk.MustExec("create table if not exists kvvars (a int key)") + tk.MustExec("insert into kvvars values (1)") + tk.MustExec("begin") + txn, err := tk.Session().Txn(false) + require.NoError(t, err) + vars := txn.GetVars().(*tikv.Variables) + require.Equal(t, 1, vars.BackoffLockFast) + require.Equal(t, 100, vars.BackOffWeight) + tk.MustExec("rollback") + tk.MustExec("set @@tidb_backoff_weight = 50") + tk.MustExec("set @@autocommit = 0") + tk.MustExec("select * from kvvars") + require.True(t, tk.Session().GetSessionVars().InTxn()) + txn, err = tk.Session().Txn(false) + require.NoError(t, err) + vars = txn.GetVars().(*tikv.Variables) + require.Equal(t, 50, vars.BackOffWeight) + + tk.MustExec("set @@autocommit = 1") + require.Nil(t, failpoint.Enable("tikvclient/probeSetVars", `return(true)`)) + tk.MustExec("select * from kvvars where a = 1") + require.Nil(t, failpoint.Disable("tikvclient/probeSetVars")) + require.True(t, transaction.SetSuccess.Load()) + transaction.SetSuccess.Store(false) +} + +func TestRemovedSysVars(t *testing.T) { + store := testkit.CreateMockStore(t) + tk := testkit.NewTestKit(t, store) + tk.MustExec("use test") + + variable.RegisterSysVar(&variable.SysVar{Scope: variable.ScopeGlobal | variable.ScopeSession, Name: "bogus_var", Value: "acdc"}) + result := tk.MustQuery("SHOW GLOBAL VARIABLES LIKE 'bogus_var'") + result.Check(testkit.Rows("bogus_var acdc")) + result = tk.MustQuery("SELECT @@GLOBAL.bogus_var") + result.Check(testkit.Rows("acdc")) + tk.MustExec("SET GLOBAL bogus_var = 'newvalue'") + + // unregister + variable.UnregisterSysVar("bogus_var") + + result = tk.MustQuery("SHOW GLOBAL VARIABLES LIKE 'bogus_var'") + result.Check(testkit.Rows()) // empty + tk.MustContainErrMsg("SET GLOBAL bogus_var = 'newvalue'", "[variable:1193]Unknown system variable 'bogus_var'") + tk.MustContainErrMsg("SELECT @@GLOBAL.bogus_var", "[variable:1193]Unknown system variable 'bogus_var'") +} + +func TestTiKVSystemVars(t *testing.T) { + store := testkit.CreateMockStore(t) + tk := testkit.NewTestKit(t, store) + tk.MustExec("use test") + + result := tk.MustQuery("SHOW GLOBAL VARIABLES LIKE 'tidb_gc_enable'") // default is on from the sysvar + result.Check(testkit.Rows("tidb_gc_enable ON")) + result = tk.MustQuery("SELECT variable_value FROM mysql.tidb WHERE variable_name = 'tikv_gc_enable'") + result.Check(testkit.Rows()) // 
but no value in the table (yet) because the value has not been set and the GC has never been run
+
+	// update will set a value in the table
+	tk.MustExec("SET GLOBAL tidb_gc_enable = 1")
+	result = tk.MustQuery("SELECT variable_value FROM mysql.tidb WHERE variable_name = 'tikv_gc_enable'")
+	result.Check(testkit.Rows("true"))
+
+	tk.MustExec("UPDATE mysql.tidb SET variable_value = 'false' WHERE variable_name='tikv_gc_enable'")
+	result = tk.MustQuery("SELECT @@tidb_gc_enable;")
+	result.Check(testkit.Rows("0")) // reads the value from mysql.tidb and changes to false
+
+	tk.MustExec("SET GLOBAL tidb_gc_concurrency = -1") // sets auto concurrency and concurrency
+	result = tk.MustQuery("SELECT variable_value FROM mysql.tidb WHERE variable_name = 'tikv_gc_auto_concurrency'")
+	result.Check(testkit.Rows("true"))
+	result = tk.MustQuery("SELECT variable_value FROM mysql.tidb WHERE variable_name = 'tikv_gc_concurrency'")
+	result.Check(testkit.Rows("-1"))
+
+	tk.MustExec("SET GLOBAL tidb_gc_concurrency = 5") // sets auto concurrency and concurrency
+	result = tk.MustQuery("SELECT variable_value FROM mysql.tidb WHERE variable_name = 'tikv_gc_auto_concurrency'")
+	result.Check(testkit.Rows("false"))
+	result = tk.MustQuery("SELECT variable_value FROM mysql.tidb WHERE variable_name = 'tikv_gc_concurrency'")
+	result.Check(testkit.Rows("5"))
+
+	tk.MustExec("UPDATE mysql.tidb SET variable_value = 'true' WHERE variable_name='tikv_gc_auto_concurrency'")
+	result = tk.MustQuery("SELECT @@tidb_gc_concurrency;")
+	result.Check(testkit.Rows("-1")) // because auto_concurrency is turned on, it takes precedence
+
+	tk.MustExec("REPLACE INTO mysql.tidb (variable_value, variable_name) VALUES ('15m', 'tikv_gc_run_interval')")
+	result = tk.MustQuery("SELECT @@GLOBAL.tidb_gc_run_interval;")
+	result.Check(testkit.Rows("15m0s"))
+	result = tk.MustQuery("SHOW GLOBAL VARIABLES LIKE 'tidb_gc_run_interval'")
+	result.Check(testkit.Rows("tidb_gc_run_interval 15m0s"))
+
+	tk.MustExec("SET GLOBAL tidb_gc_run_interval = '9m'") // too small
+	tk.MustQuery("SHOW WARNINGS").Check(testkit.Rows("Warning 1292 Truncated incorrect tidb_gc_run_interval value: '9m'"))
+	result = tk.MustQuery("SHOW GLOBAL VARIABLES LIKE 'tidb_gc_run_interval'")
+	result.Check(testkit.Rows("tidb_gc_run_interval 10m0s"))
+
+	tk.MustExec("SET GLOBAL tidb_gc_run_interval = '700000000000ns'") // specified in ns, also valid
+
+	_, err := tk.Exec("SET GLOBAL tidb_gc_run_interval = '11mins'")
+	require.Equal(t, "[variable:1232]Incorrect argument type to variable 'tidb_gc_run_interval'", err.Error())
+}
+
+func TestUpgradeSysvars(t *testing.T) {
+	store := testkit.CreateMockStore(t)
+
+	tk := testkit.NewTestKit(t, store)
+	tk.MustExec("use test")
+	se := tk.Session().(variable.GlobalVarAccessor)
+
+	// Set the global var to a non-canonical form of the value,
+	// i.e. implying that it was set by an earlier version of TiDB.
+
+	tk.MustExec(`REPLACE INTO mysql.global_variables (variable_name, variable_value) VALUES ('tidb_enable_noop_functions', '0')`)
+	domain.GetDomain(tk.Session()).NotifyUpdateSysVarCache(true) // update cache
+	v, err := se.GetGlobalSysVar("tidb_enable_noop_functions")
+	require.NoError(t, err)
+	require.Equal(t, "OFF", v)
+
+	// Set the global var to '', the invalid value written by TiDB 4.0.16.
+	// The error is quashed by GetGlobalSysVar, and the default value is restored.
+	// This helps callers of GetGlobalSysVar(), which can't individually be expected
+	// to handle upgrade/downgrade issues correctly.
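+	//
+	// In other words, callers can rely on a property like this (an
+	// illustrative sketch of the guarantee, not real framework code):
+	//
+	//   v, _ := se.GetGlobalSysVar("rpl_semi_sync_slave_enabled")
+	//   // v is always a canonical, validated value ("OFF" here), even when
+	//   // the stored row still holds a legacy value such as ''.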
+
+	tk.MustExec(`REPLACE INTO mysql.global_variables (variable_name, variable_value) VALUES ('rpl_semi_sync_slave_enabled', '')`)
+	domain.GetDomain(tk.Session()).NotifyUpdateSysVarCache(true) // update cache
+	v, err = se.GetGlobalSysVar("rpl_semi_sync_slave_enabled")
+	require.NoError(t, err)
+	require.Equal(t, "OFF", v) // the default value is restored.
+	result := tk.MustQuery("SHOW VARIABLES LIKE 'rpl_semi_sync_slave_enabled'")
+	result.Check(testkit.Rows("rpl_semi_sync_slave_enabled OFF"))
+
+	// Ensure a variable value out of range is converted back into range after upgrade.
+	// This further helps for https://github.com/pingcap/tidb/pull/28842
+
+	tk.MustExec(`REPLACE INTO mysql.global_variables (variable_name, variable_value) VALUES ('tidb_executor_concurrency', '999')`)
+	domain.GetDomain(tk.Session()).NotifyUpdateSysVarCache(true) // update cache
+	v, err = se.GetGlobalSysVar("tidb_executor_concurrency")
+	require.NoError(t, err)
+	require.Equal(t, "256", v) // the value is capped at the maximum.
+
+	// Handle the case of a completely bogus value from an earlier version of TiDB.
+	// This could be the case if an ENUM sysvar removes a value.
+
+	tk.MustExec(`REPLACE INTO mysql.global_variables (variable_name, variable_value) VALUES ('tidb_enable_noop_functions', 'SOMEVAL')`)
+	domain.GetDomain(tk.Session()).NotifyUpdateSysVarCache(true) // update cache
+	v, err = se.GetGlobalSysVar("tidb_enable_noop_functions")
+	require.NoError(t, err)
+	require.Equal(t, "OFF", v) // the default value is restored.
+}
+
+func TestSetInstanceSysvarBySetGlobalSysVar(t *testing.T) {
+	varName := "tidb_general_log"
+	defaultValue := "OFF" // This is the default value for tidb_general_log
+
+	store := testkit.CreateMockStore(t)
+
+	tk := testkit.NewTestKit(t, store)
+	tk.MustExec("use test")
+	se := tk.Session().(variable.GlobalVarAccessor)
+
+	// Get globalSysVar twice and get the same default value
+	v, err := se.GetGlobalSysVar(varName)
+	require.NoError(t, err)
+	require.Equal(t, defaultValue, v)
+	v, err = se.GetGlobalSysVar(varName)
+	require.NoError(t, err)
+	require.Equal(t, defaultValue, v)
+
+	// session.GetGlobalSysVar does not observe the value that session.SetGlobalSysVar writes,
+	// because SetGlobalSysVar calls SetGlobalFromHook, which uses TiDBGeneralLog's SetGlobal,
+	// while GetGlobalSysVar cannot access TiDBGeneralLog's GetGlobal.
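+	// Roughly, the asymmetry looks like this (an illustrative sketch, not
+	// the actual sysvar framework code):
+	//
+	//   SetGlobalSysVar(name, val) -> runs the variable's SetGlobal hook
+	//   GetGlobalSysVar(name)      -> reads the cached/stored global value
+	//                                 and never consults the GetGlobal hook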
+ + // set to "1" + err = se.SetGlobalSysVar(context.Background(), varName, "ON") + require.NoError(t, err) + v, err = se.GetGlobalSysVar(varName) + tk.MustQuery("select @@global.tidb_general_log").Check(testkit.Rows("1")) + require.NoError(t, err) + require.Equal(t, defaultValue, v) + + // set back to "0" + err = se.SetGlobalSysVar(context.Background(), varName, defaultValue) + require.NoError(t, err) + v, err = se.GetGlobalSysVar(varName) + tk.MustQuery("select @@global.tidb_general_log").Check(testkit.Rows("0")) + require.NoError(t, err) + require.Equal(t, defaultValue, v) +} + +func TestEnableLegacyInstanceScope(t *testing.T) { + store := testkit.CreateMockStore(t) + + tk := testkit.NewTestKit(t, store) + + // enable 'switching' to SESSION variables + tk.MustExec("set tidb_enable_legacy_instance_scope = 1") + tk.MustExec("set tidb_general_log = 1") + tk.MustQuery(`show warnings`).Check(testkit.Rows(fmt.Sprintf("Warning %d modifying tidb_general_log will require SET GLOBAL in a future version of TiDB", errno.ErrInstanceScope))) + require.True(t, tk.Session().GetSessionVars().EnableLegacyInstanceScope) + + // disable 'switching' to SESSION variables + tk.MustExec("set tidb_enable_legacy_instance_scope = 0") + tk.MustGetErrCode("set tidb_general_log = 1", errno.ErrGlobalVariable) + require.False(t, tk.Session().GetSessionVars().EnableLegacyInstanceScope) +} + +func TestSetPDClientDynamicOption(t *testing.T) { + store := testkit.CreateMockStore(t) + + tk := testkit.NewTestKit(t, store) + tk.MustQuery("select @@tidb_tso_client_batch_max_wait_time;").Check(testkit.Rows("0")) + tk.MustExec("set global tidb_tso_client_batch_max_wait_time = 0.5;") + tk.MustQuery("select @@tidb_tso_client_batch_max_wait_time;").Check(testkit.Rows("0.5")) + tk.MustExec("set global tidb_tso_client_batch_max_wait_time = 1;") + tk.MustQuery("select @@tidb_tso_client_batch_max_wait_time;").Check(testkit.Rows("1")) + tk.MustExec("set global tidb_tso_client_batch_max_wait_time = 1.5;") + tk.MustQuery("select @@tidb_tso_client_batch_max_wait_time;").Check(testkit.Rows("1.5")) + tk.MustExec("set global tidb_tso_client_batch_max_wait_time = 10;") + tk.MustQuery("select @@tidb_tso_client_batch_max_wait_time;").Check(testkit.Rows("10")) + require.Error(t, tk.ExecToErr("set tidb_tso_client_batch_max_wait_time = 0;")) + tk.MustExec("set global tidb_tso_client_batch_max_wait_time = -1;") + tk.MustQuery("show warnings").Check(testkit.RowsWithSep("|", "Warning|1292|Truncated incorrect tidb_tso_client_batch_max_wait_time value: '-1'")) + tk.MustQuery("select @@tidb_tso_client_batch_max_wait_time;").Check(testkit.Rows("0")) + tk.MustExec("set global tidb_tso_client_batch_max_wait_time = -0.1;") + tk.MustQuery("show warnings").Check(testkit.RowsWithSep("|", "Warning|1292|Truncated incorrect tidb_tso_client_batch_max_wait_time value: '-0.1'")) + tk.MustQuery("select @@tidb_tso_client_batch_max_wait_time;").Check(testkit.Rows("0")) + tk.MustExec("set global tidb_tso_client_batch_max_wait_time = 10.1;") + tk.MustQuery("show warnings").Check(testkit.RowsWithSep("|", "Warning|1292|Truncated incorrect tidb_tso_client_batch_max_wait_time value: '10.1'")) + tk.MustQuery("select @@tidb_tso_client_batch_max_wait_time;").Check(testkit.Rows("10")) + tk.MustExec("set global tidb_tso_client_batch_max_wait_time = 11;") + tk.MustQuery("show warnings").Check(testkit.RowsWithSep("|", "Warning|1292|Truncated incorrect tidb_tso_client_batch_max_wait_time value: '11'")) + tk.MustQuery("select 
@@tidb_tso_client_batch_max_wait_time;").Check(testkit.Rows("10"))
+
+	tk.MustQuery("select @@tidb_enable_tso_follower_proxy;").Check(testkit.Rows("0"))
+	tk.MustExec("set global tidb_enable_tso_follower_proxy = on;")
+	tk.MustQuery("select @@tidb_enable_tso_follower_proxy;").Check(testkit.Rows("1"))
+	tk.MustExec("set global tidb_enable_tso_follower_proxy = off;")
+	tk.MustQuery("select @@tidb_enable_tso_follower_proxy;").Check(testkit.Rows("0"))
+	require.Error(t, tk.ExecToErr("set tidb_tso_client_batch_max_wait_time = 0;"))
+}
+
+func TestCastTimeToDate(t *testing.T) {
+	store := testkit.CreateMockStore(t)
+
+	tk := testkit.NewTestKit(t, store)
+	tk.MustExec("use test")
+	tk.MustExec("set time_zone = '-8:00'")
+	date := time.Now().In(time.FixedZone("", -8*int(time.Hour/time.Second)))
+	tk.MustQuery("select cast(time('12:23:34') as date)").Check(testkit.Rows(date.Format(time.DateOnly)))
+
+	tk.MustExec("set time_zone = '+08:00'")
+	date = time.Now().In(time.FixedZone("", 8*int(time.Hour/time.Second)))
+	tk.MustQuery("select cast(time('12:23:34') as date)").Check(testkit.Rows(date.Format(time.DateOnly)))
+}
+
+func TestSetGlobalTZ(t *testing.T) {
+	store := testkit.CreateMockStore(t)
+
+	tk := testkit.NewTestKit(t, store)
+	tk.MustExec("set time_zone = '+08:00'")
+	tk.MustQuery("show variables like 'time_zone'").Check(testkit.Rows("time_zone +08:00"))
+
+	tk.MustExec("set global time_zone = '+00:00'")
+
+	tk.MustQuery("show variables like 'time_zone'").Check(testkit.Rows("time_zone +08:00"))
+
+	tk1 := testkit.NewTestKit(t, store)
+	tk1.MustQuery("show variables like 'time_zone'").Check(testkit.Rows("time_zone +00:00"))
+}
+
+func TestSetVarHint(t *testing.T) {
+	store := testkit.CreateMockStore(t)
+
+	tk := testkit.NewTestKit(t, store)
+	tk.MustExec("use test")
+
+	require.NoError(t, tk.Session().GetSessionVars().SetSystemVar("sql_mode", mysql.DefaultSQLMode))
+	tk.MustQuery("SELECT /*+ SET_VAR(sql_mode=ALLOW_INVALID_DATES) */ @@sql_mode;").Check(testkit.Rows("ALLOW_INVALID_DATES"))
+	require.Len(t, tk.Session().GetSessionVars().StmtCtx.GetWarnings(), 0)
+	tk.MustQuery("SELECT @@sql_mode;").Check(testkit.Rows(mysql.DefaultSQLMode))
+
+	require.NoError(t, tk.Session().GetSessionVars().SetSystemVar("tmp_table_size", "16777216"))
+	tk.MustQuery("SELECT /*+ SET_VAR(tmp_table_size=1024) */ @@tmp_table_size;").Check(testkit.Rows("1024"))
+	require.Len(t, tk.Session().GetSessionVars().StmtCtx.GetWarnings(), 0)
+	tk.MustQuery("SELECT @@tmp_table_size;").Check(testkit.Rows("16777216"))
+
+	require.NoError(t, tk.Session().GetSessionVars().SetSystemVar("range_alloc_block_size", "4096"))
+	tk.MustQuery("SELECT /*+ SET_VAR(range_alloc_block_size=4294967295) */ @@range_alloc_block_size;").Check(testkit.Rows("4294967295"))
+	require.Len(t, tk.Session().GetSessionVars().StmtCtx.GetWarnings(), 0)
+	tk.MustQuery("SELECT @@range_alloc_block_size;").Check(testkit.Rows("4096"))
+
+	require.NoError(t, tk.Session().GetSessionVars().SetSystemVar("max_execution_time", "0"))
+	tk.MustQuery("SELECT /*+ SET_VAR(max_execution_time=1) */ @@max_execution_time;").Check(testkit.Rows("1"))
+	require.Len(t, tk.Session().GetSessionVars().StmtCtx.GetWarnings(), 0)
+	tk.MustQuery("SELECT @@max_execution_time;").Check(testkit.Rows("0"))
+
+	require.NoError(t, tk.Session().GetSessionVars().SetSystemVar("tidb_kv_read_timeout", "0"))
+	tk.MustQuery("SELECT /*+ SET_VAR(tidb_kv_read_timeout=10) */ @@tidb_kv_read_timeout;").Check(testkit.Rows("10"))
+	require.Len(t, tk.Session().GetSessionVars().StmtCtx.GetWarnings(), 0)
+
tk.MustQuery("SELECT @@tidb_kv_read_timeout;").Check(testkit.Rows("0")) + + tk.MustExec("set @@tidb_kv_read_timeout = 5") + tk.MustQuery("SELECT /*+ tidb_kv_read_timeout(1) */ @@tidb_kv_read_timeout;").Check(testkit.Rows("5")) + require.Equal(t, tk.Session().GetSessionVars().GetTidbKvReadTimeout(), uint64(1)) + tk.MustQuery("SELECT @@tidb_kv_read_timeout;").Check(testkit.Rows("5")) + require.Equal(t, tk.Session().GetSessionVars().GetTidbKvReadTimeout(), uint64(5)) + + require.NoError(t, tk.Session().GetSessionVars().SetSystemVar("time_zone", "SYSTEM")) + tk.MustQuery("SELECT /*+ SET_VAR(time_zone='+12:00') */ @@time_zone;").Check(testkit.Rows("+12:00")) + require.Len(t, tk.Session().GetSessionVars().StmtCtx.GetWarnings(), 0) + tk.MustQuery("SELECT @@time_zone;").Check(testkit.Rows("SYSTEM")) + + require.NoError(t, tk.Session().GetSessionVars().SetSystemVar("join_buffer_size", "262144")) + tk.MustQuery("SELECT /*+ SET_VAR(join_buffer_size=128) */ @@join_buffer_size;").Check(testkit.Rows("128")) + require.Len(t, tk.Session().GetSessionVars().StmtCtx.GetWarnings(), 0) + tk.MustQuery("SELECT @@join_buffer_size;").Check(testkit.Rows("262144")) + + require.NoError(t, tk.Session().GetSessionVars().SetSystemVar("max_length_for_sort_data", "1024")) + tk.MustQuery("SELECT /*+ SET_VAR(max_length_for_sort_data=4) */ @@max_length_for_sort_data;").Check(testkit.Rows("4")) + require.Len(t, tk.Session().GetSessionVars().StmtCtx.GetWarnings(), 0) + tk.MustQuery("SELECT @@max_length_for_sort_data;").Check(testkit.Rows("1024")) + + require.NoError(t, tk.Session().GetSessionVars().SetSystemVar("max_error_count", "64")) + tk.MustQuery("SELECT /*+ SET_VAR(max_error_count=0) */ @@max_error_count;").Check(testkit.Rows("0")) + require.Len(t, tk.Session().GetSessionVars().StmtCtx.GetWarnings(), 0) + tk.MustQuery("SELECT @@max_error_count;").Check(testkit.Rows("64")) + + require.NoError(t, tk.Session().GetSessionVars().SetSystemVar("sql_buffer_result", "OFF")) + tk.MustQuery("SELECT /*+ SET_VAR(sql_buffer_result=ON) */ @@sql_buffer_result;").Check(testkit.Rows("ON")) + require.Len(t, tk.Session().GetSessionVars().StmtCtx.GetWarnings(), 0) + tk.MustQuery("SELECT @@sql_buffer_result;").Check(testkit.Rows("OFF")) + + require.NoError(t, tk.Session().GetSessionVars().SetSystemVar("max_heap_table_size", "16777216")) + tk.MustQuery("SELECT /*+ SET_VAR(max_heap_table_size=16384) */ @@max_heap_table_size;").Check(testkit.Rows("16384")) + require.Len(t, tk.Session().GetSessionVars().StmtCtx.GetWarnings(), 0) + tk.MustQuery("SELECT @@max_heap_table_size;").Check(testkit.Rows("16777216")) + + require.NoError(t, tk.Session().GetSessionVars().SetSystemVar("tmp_table_size", "16777216")) + tk.MustQuery("SELECT /*+ SET_VAR(tmp_table_size=16384) */ @@tmp_table_size;").Check(testkit.Rows("16384")) + require.Len(t, tk.Session().GetSessionVars().StmtCtx.GetWarnings(), 0) + tk.MustQuery("SELECT @@tmp_table_size;").Check(testkit.Rows("16777216")) + + require.NoError(t, tk.Session().GetSessionVars().SetSystemVar("div_precision_increment", "4")) + tk.MustQuery("SELECT /*+ SET_VAR(div_precision_increment=0) */ @@div_precision_increment;").Check(testkit.Rows("0")) + require.Len(t, tk.Session().GetSessionVars().StmtCtx.GetWarnings(), 0) + tk.MustQuery("SELECT @@div_precision_increment;").Check(testkit.Rows("4")) + + require.NoError(t, tk.Session().GetSessionVars().SetSystemVar("sql_auto_is_null", "OFF")) + require.NoError(t, tk.Session().GetSessionVars().SetSystemVar("tidb_enable_noop_functions", "ON")) + tk.MustQuery("SELECT /*+ 
SET_VAR(sql_auto_is_null=1) */ @@sql_auto_is_null;").Check(testkit.Rows("1")) + require.Len(t, tk.Session().GetSessionVars().StmtCtx.GetWarnings(), 0) + require.NoError(t, tk.Session().GetSessionVars().SetSystemVarWithoutValidation("tidb_enable_noop_functions", "OFF")) + tk.MustQuery("SELECT @@sql_auto_is_null;").Check(testkit.Rows("0")) + + require.NoError(t, tk.Session().GetSessionVars().SetSystemVar("sort_buffer_size", "262144")) + tk.MustQuery("SELECT /*+ SET_VAR(sort_buffer_size=32768) */ @@sort_buffer_size;").Check(testkit.Rows("32768")) + require.Len(t, tk.Session().GetSessionVars().StmtCtx.GetWarnings(), 0) + tk.MustQuery("SELECT @@sort_buffer_size;").Check(testkit.Rows("262144")) + + require.NoError(t, tk.Session().GetSessionVars().SetSystemVar("max_join_size", "18446744073709551615")) + tk.MustQuery("SELECT /*+ SET_VAR(max_join_size=1) */ @@max_join_size;").Check(testkit.Rows("1")) + require.Len(t, tk.Session().GetSessionVars().StmtCtx.GetWarnings(), 0) + tk.MustQuery("SELECT @@max_join_size;").Check(testkit.Rows("18446744073709551615")) + + require.NoError(t, tk.Session().GetSessionVars().SetSystemVar("max_seeks_for_key", "18446744073709551615")) + tk.MustQuery("SELECT /*+ SET_VAR(max_seeks_for_key=1) */ @@max_seeks_for_key;").Check(testkit.Rows("1")) + require.Len(t, tk.Session().GetSessionVars().StmtCtx.GetWarnings(), 0) + tk.MustQuery("SELECT @@max_seeks_for_key;").Check(testkit.Rows("18446744073709551615")) + + require.NoError(t, tk.Session().GetSessionVars().SetSystemVar("max_sort_length", "1024")) + tk.MustQuery("SELECT /*+ SET_VAR(max_sort_length=4) */ @@max_sort_length;").Check(testkit.Rows("4")) + require.Len(t, tk.Session().GetSessionVars().StmtCtx.GetWarnings(), 0) + tk.MustQuery("SELECT @@max_sort_length;").Check(testkit.Rows("1024")) + + require.NoError(t, tk.Session().GetSessionVars().SetSystemVar("bulk_insert_buffer_size", "8388608")) + tk.MustQuery("SELECT /*+ SET_VAR(bulk_insert_buffer_size=0) */ @@bulk_insert_buffer_size;").Check(testkit.Rows("0")) + require.Len(t, tk.Session().GetSessionVars().StmtCtx.GetWarnings(), 0) + tk.MustQuery("SELECT @@bulk_insert_buffer_size;").Check(testkit.Rows("8388608")) + + require.NoError(t, tk.Session().GetSessionVars().SetSystemVar("sql_big_selects", "1")) + tk.MustQuery("SELECT /*+ SET_VAR(sql_big_selects=0) */ @@sql_big_selects;").Check(testkit.Rows("0")) + require.Len(t, tk.Session().GetSessionVars().StmtCtx.GetWarnings(), 0) + tk.MustQuery("SELECT @@sql_big_selects;").Check(testkit.Rows("1")) + + require.NoError(t, tk.Session().GetSessionVars().SetSystemVar("read_rnd_buffer_size", "262144")) + tk.MustQuery("SELECT /*+ SET_VAR(read_rnd_buffer_size=1) */ @@read_rnd_buffer_size;").Check(testkit.Rows("1")) + require.Len(t, tk.Session().GetSessionVars().StmtCtx.GetWarnings(), 0) + tk.MustQuery("SELECT @@read_rnd_buffer_size;").Check(testkit.Rows("262144")) + + require.NoError(t, tk.Session().GetSessionVars().SetSystemVar("unique_checks", "1")) + tk.MustQuery("SELECT /*+ SET_VAR(unique_checks=0) */ @@unique_checks;").Check(testkit.Rows("0")) + require.Len(t, tk.Session().GetSessionVars().StmtCtx.GetWarnings(), 0) + tk.MustQuery("SELECT @@unique_checks;").Check(testkit.Rows("1")) + + require.NoError(t, tk.Session().GetSessionVars().SetSystemVar("read_buffer_size", "131072")) + tk.MustQuery("SELECT /*+ SET_VAR(read_buffer_size=8192) */ @@read_buffer_size;").Check(testkit.Rows("8192")) + require.Len(t, tk.Session().GetSessionVars().StmtCtx.GetWarnings(), 0) + tk.MustQuery("SELECT @@read_buffer_size;").Check(testkit.Rows("131072")) 
+ + require.NoError(t, tk.Session().GetSessionVars().SetSystemVar("default_tmp_storage_engine", "InnoDB")) + tk.MustQuery("SELECT /*+ SET_VAR(default_tmp_storage_engine='CSV') */ @@default_tmp_storage_engine;").Check(testkit.Rows("CSV")) + require.Len(t, tk.Session().GetSessionVars().StmtCtx.GetWarnings(), 0) + tk.MustQuery("SELECT @@default_tmp_storage_engine;").Check(testkit.Rows("InnoDB")) + + require.NoError(t, tk.Session().GetSessionVars().SetSystemVar("optimizer_search_depth", "62")) + tk.MustQuery("SELECT /*+ SET_VAR(optimizer_search_depth=1) */ @@optimizer_search_depth;").Check(testkit.Rows("1")) + require.Len(t, tk.Session().GetSessionVars().StmtCtx.GetWarnings(), 0) + tk.MustQuery("SELECT @@optimizer_search_depth;").Check(testkit.Rows("62")) + + require.NoError(t, tk.Session().GetSessionVars().SetSystemVar("max_points_in_geometry", "65536")) + tk.MustQuery("SELECT /*+ SET_VAR(max_points_in_geometry=3) */ @@max_points_in_geometry;").Check(testkit.Rows("3")) + require.Len(t, tk.Session().GetSessionVars().StmtCtx.GetWarnings(), 0) + tk.MustQuery("SELECT @@max_points_in_geometry;").Check(testkit.Rows("65536")) + + require.NoError(t, tk.Session().GetSessionVars().SetSystemVar("updatable_views_with_limit", "YES")) + tk.MustQuery("SELECT /*+ SET_VAR(updatable_views_with_limit=0) */ @@updatable_views_with_limit;").Check(testkit.Rows("0")) + require.Len(t, tk.Session().GetSessionVars().StmtCtx.GetWarnings(), 0) + tk.MustQuery("SELECT @@updatable_views_with_limit;").Check(testkit.Rows("YES")) + + require.NoError(t, tk.Session().GetSessionVars().SetSystemVar("optimizer_prune_level", "1")) + tk.MustQuery("SELECT /*+ SET_VAR(optimizer_prune_level=0) */ @@optimizer_prune_level;").Check(testkit.Rows("0")) + require.Len(t, tk.Session().GetSessionVars().StmtCtx.GetWarnings(), 0) + tk.MustQuery("SELECT @@optimizer_prune_level;").Check(testkit.Rows("1")) + + require.NoError(t, tk.Session().GetSessionVars().SetSystemVar("group_concat_max_len", "1024")) + tk.MustQuery("SELECT /*+ SET_VAR(group_concat_max_len=4) */ @@group_concat_max_len;").Check(testkit.Rows("4")) + require.Len(t, tk.Session().GetSessionVars().StmtCtx.GetWarnings(), 0) + tk.MustQuery("SELECT @@group_concat_max_len;").Check(testkit.Rows("1024")) + + require.NoError(t, tk.Session().GetSessionVars().SetSystemVar("eq_range_index_dive_limit", "200")) + tk.MustQuery("SELECT /*+ SET_VAR(eq_range_index_dive_limit=0) */ @@eq_range_index_dive_limit;").Check(testkit.Rows("0")) + require.Len(t, tk.Session().GetSessionVars().StmtCtx.GetWarnings(), 0) + tk.MustQuery("SELECT @@eq_range_index_dive_limit;").Check(testkit.Rows("200")) + + require.NoError(t, tk.Session().GetSessionVars().SetSystemVar("sql_safe_updates", "0")) + tk.MustQuery("SELECT /*+ SET_VAR(sql_safe_updates=1) */ @@sql_safe_updates;").Check(testkit.Rows("1")) + require.Len(t, tk.Session().GetSessionVars().StmtCtx.GetWarnings(), 0) + tk.MustQuery("SELECT @@sql_safe_updates;").Check(testkit.Rows("0")) + + require.NoError(t, tk.Session().GetSessionVars().SetSystemVar("end_markers_in_json", "0")) + tk.MustQuery("SELECT /*+ SET_VAR(end_markers_in_json=1) */ @@end_markers_in_json;").Check(testkit.Rows("1")) + require.Len(t, tk.Session().GetSessionVars().StmtCtx.GetWarnings(), 0) + tk.MustQuery("SELECT @@end_markers_in_json;").Check(testkit.Rows("0")) + + require.NoError(t, tk.Session().GetSessionVars().SetSystemVar("windowing_use_high_precision", "ON")) + tk.MustQuery("SELECT /*+ SET_VAR(windowing_use_high_precision=OFF) */ @@windowing_use_high_precision;").Check(testkit.Rows("0")) + 
require.Len(t, tk.Session().GetSessionVars().StmtCtx.GetWarnings(), 0)
+	tk.MustQuery("SELECT @@windowing_use_high_precision;").Check(testkit.Rows("1"))
+
+	tk.MustExec("SELECT /*+ SET_VAR(sql_safe_updates = 1) SET_VAR(max_heap_table_size = 1G) */ 1;")
+	require.Len(t, tk.Session().GetSessionVars().StmtCtx.GetWarnings(), 0)
+
+	tk.MustExec("SELECT /*+ SET_VAR(collation_server = 'utf8') */ 1;")
+	require.Len(t, tk.Session().GetSessionVars().StmtCtx.GetWarnings(), 1)
+	require.EqualError(t, tk.Session().GetSessionVars().StmtCtx.GetWarnings()[0].Err, "[planner:3637]Variable 'collation_server' cannot be set using SET_VAR hint.")
+
+	tk.MustExec("SELECT /*+ SET_VAR(max_size = 1G) */ 1;")
+	require.Len(t, tk.Session().GetSessionVars().StmtCtx.GetWarnings(), 1)
+	require.EqualError(t, tk.Session().GetSessionVars().StmtCtx.GetWarnings()[0].Err, "[planner:3128]Unresolved name 'max_size' for SET_VAR hint")
+
+	tk.MustExec("SELECT /*+ SET_VAR(group_concat_max_len = 1024) SET_VAR(group_concat_max_len = 2048) */ 1;")
+	require.Len(t, tk.Session().GetSessionVars().StmtCtx.GetWarnings(), 1)
+	require.EqualError(t, tk.Session().GetSessionVars().StmtCtx.GetWarnings()[0].Err, "[planner:3126]Hint SET_VAR(group_concat_max_len=2048) is ignored as conflicting/duplicated.")
+}
+
+func TestGlobalVarAccessor(t *testing.T) {
+	varName := "max_allowed_packet"
+	varValue := strconv.FormatUint(variable.DefMaxAllowedPacket, 10) // This is the default value for max_allowed_packet
+
+	// The value of max_allowed_packet should be a multiple of 1024,
+	// so the settings of varValue1 and varValue2 will be truncated to varValue0
+	varValue0 := "4194304"
+	varValue1 := "4194305"
+	varValue2 := "4194306"
+
+	store := testkit.CreateMockStore(t)
+
+	tk := testkit.NewTestKit(t, store)
+	tk.MustExec("use test")
+
+	se := tk.Session().(variable.GlobalVarAccessor)
+	// Get globalSysVar twice and get the same value
+	v, err := se.GetGlobalSysVar(varName)
+	require.NoError(t, err)
+	require.Equal(t, varValue, v)
+	v, err = se.GetGlobalSysVar(varName)
+	require.NoError(t, err)
+	require.Equal(t, varValue, v)
+	// Set global var to another value
+	err = se.SetGlobalSysVar(context.Background(), varName, varValue1)
+	require.NoError(t, err)
+	v, err = se.GetGlobalSysVar(varName)
+	require.NoError(t, err)
+	require.Equal(t, varValue0, v)
+	require.NoError(t, tk.Session().CommitTxn(context.TODO()))
+
+	tk1 := testkit.NewTestKit(t, store)
+	tk1.MustExec("use test")
+	se1 := tk1.Session().(variable.GlobalVarAccessor)
+	v, err = se1.GetGlobalSysVar(varName)
+	require.NoError(t, err)
+	require.Equal(t, varValue0, v)
+	err = se1.SetGlobalSysVar(context.Background(), varName, varValue2)
+	require.NoError(t, err)
+	v, err = se1.GetGlobalSysVar(varName)
+	require.NoError(t, err)
+	require.Equal(t, varValue0, v)
+	require.NoError(t, tk1.Session().CommitTxn(context.TODO()))
+
+	// Make sure the change is visible to any client that accesses that global variable.
+	v, err = se.GetGlobalSysVar(varName)
+	require.NoError(t, err)
+	require.Equal(t, varValue0, v)
+
+	// For issue 10955, make sure the new session loads `max_execution_time` into sessionVars.
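+	// (Expectation sketch: a session created after `set @@global.max_execution_time = 100`
+	// must observe 100 via its own sessionVars, with no explicit SET of its own.)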
+ tk1.MustExec("set @@global.max_execution_time = 100") + tk2 := testkit.NewTestKit(t, store) + tk2.MustExec("use test") + require.Equal(t, uint64(100), tk2.Session().GetSessionVars().MaxExecutionTime) + tk1.MustExec("set @@global.max_execution_time = 0") + + result := tk.MustQuery("show global variables where variable_name='sql_select_limit';") + result.Check(testkit.Rows("sql_select_limit 18446744073709551615")) + result = tk.MustQuery("show session variables where variable_name='sql_select_limit';") + result.Check(testkit.Rows("sql_select_limit 18446744073709551615")) + tk.MustExec("set session sql_select_limit=100000000000;") + result = tk.MustQuery("show global variables where variable_name='sql_select_limit';") + result.Check(testkit.Rows("sql_select_limit 18446744073709551615")) + result = tk.MustQuery("show session variables where variable_name='sql_select_limit';") + result.Check(testkit.Rows("sql_select_limit 100000000000")) + tk.MustExec("set @@global.sql_select_limit = 1") + result = tk.MustQuery("show global variables where variable_name='sql_select_limit';") + result.Check(testkit.Rows("sql_select_limit 1")) + tk.MustExec("set @@global.sql_select_limit = default") + result = tk.MustQuery("show global variables where variable_name='sql_select_limit';") + result.Check(testkit.Rows("sql_select_limit 18446744073709551615")) + + result = tk.MustQuery("select @@global.autocommit;") + result.Check(testkit.Rows("1")) + result = tk.MustQuery("select @@autocommit;") + result.Check(testkit.Rows("1")) + tk.MustExec("set @@global.autocommit = 0;") + result = tk.MustQuery("select @@global.autocommit;") + result.Check(testkit.Rows("0")) + result = tk.MustQuery("select @@autocommit;") + result.Check(testkit.Rows("1")) + tk.MustExec("set @@global.autocommit=1") + + err = tk.ExecToErr("set global time_zone = 'timezone'") + require.Error(t, err) + require.True(t, terror.ErrorEqual(err, variable.ErrUnknownTimeZone)) +} + +func TestGetSysVariables(t *testing.T) { + store := testkit.CreateMockStore(t) + + tk := testkit.NewTestKit(t, store) + tk.MustExec("use test") + + // Test ScopeSession + tk.MustExec("select @@warning_count") + tk.MustExec("select @@session.warning_count") + tk.MustExec("select @@local.warning_count") + err := tk.ExecToErr("select @@global.warning_count") + require.True(t, terror.ErrorEqual(err, variable.ErrIncorrectScope), fmt.Sprintf("err %v", err)) + + // Test ScopeGlobal + tk.MustExec("select @@max_connections") + tk.MustExec("select @@global.max_connections") + tk.MustGetErrMsg("select @@session.max_connections", "[variable:1238]Variable 'max_connections' is a GLOBAL variable") + tk.MustGetErrMsg("select @@local.max_connections", "[variable:1238]Variable 'max_connections' is a GLOBAL variable") + + // Test ScopeNone + tk.MustExec("select @@performance_schema_max_mutex_classes") + tk.MustExec("select @@global.performance_schema_max_mutex_classes") + // For issue 19524, test + tk.MustExec("select @@session.performance_schema_max_mutex_classes") + tk.MustExec("select @@local.performance_schema_max_mutex_classes") + tk.MustGetErrMsg("select @@global.last_insert_id", "[variable:1238]Variable 'last_insert_id' is a SESSION variable") +} + +func TestPrepareExecuteWithSQLHints(t *testing.T) { + store := testkit.CreateMockStore(t) + tk := testkit.NewTestKit(t, store) + se := tk.Session() + se.SetConnectionID(1) + tk.MustExec("use test") + tk.MustExec("create table t(a int primary key)") + + type hintCheck struct { + hint string + check func(*stmtctx.StmtHints) + } + + hintChecks := 
[]hintCheck{ + { + hint: "MEMORY_QUOTA(1024 MB)", + check: func(stmtHint *stmtctx.StmtHints) { + require.True(t, stmtHint.HasMemQuotaHint) + require.Equal(t, int64(1024*1024*1024), stmtHint.MemQuotaQuery) + }, + }, + { + hint: "READ_CONSISTENT_REPLICA()", + check: func(stmtHint *stmtctx.StmtHints) { + require.True(t, stmtHint.HasReplicaReadHint) + require.Equal(t, byte(tikv.ReplicaReadFollower), stmtHint.ReplicaRead) + }, + }, + { + hint: "MAX_EXECUTION_TIME(1000)", + check: func(stmtHint *stmtctx.StmtHints) { + require.True(t, stmtHint.HasMaxExecutionTime) + require.Equal(t, uint64(1000), stmtHint.MaxExecutionTime) + }, + }, + { + hint: "USE_TOJA(TRUE)", + check: func(stmtHint *stmtctx.StmtHints) { + require.True(t, stmtHint.HasAllowInSubqToJoinAndAggHint) + require.True(t, stmtHint.AllowInSubqToJoinAndAgg) + }, + }, + { + hint: "RESOURCE_GROUP(rg1)", + check: func(stmtHint *stmtctx.StmtHints) { + require.True(t, stmtHint.HasResourceGroup) + require.Equal(t, "rg1", stmtHint.ResourceGroup) + }, + }, + } + + for i, check := range hintChecks { + // common path + tk.MustExec(fmt.Sprintf("prepare stmt%d from 'select /*+ %s */ * from t'", i, check.hint)) + for j := 0; j < 10; j++ { + tk.MustQuery(fmt.Sprintf("execute stmt%d", i)) + check.check(&tk.Session().GetSessionVars().StmtCtx.StmtHints) + } + // fast path + tk.MustExec(fmt.Sprintf("prepare fast%d from 'select /*+ %s */ * from t where a = 1'", i, check.hint)) + for j := 0; j < 10; j++ { + tk.MustQuery(fmt.Sprintf("execute fast%d", i)) + check.check(&tk.Session().GetSessionVars().StmtCtx.StmtHints) + } + } +} diff --git a/sessionctx/stmtctx/BUILD.bazel b/sessionctx/stmtctx/BUILD.bazel new file mode 100644 index 0000000000000..e16d9a44da20d --- /dev/null +++ b/sessionctx/stmtctx/BUILD.bazel @@ -0,0 +1,54 @@ +load("@io_bazel_rules_go//go:def.bzl", "go_library", "go_test") + +go_library( + name = "stmtctx", + srcs = ["stmtctx.go"], + importpath = "github.com/pingcap/tidb/sessionctx/stmtctx", + visibility = ["//visibility:public"], + deps = [ + "//domain/resourcegroup", + "//errno", + "//parser", + "//parser/ast", + "//parser/model", + "//parser/mysql", + "//parser/terror", + "//util/disk", + "//util/execdetails", + "//util/memory", + "//util/resourcegrouptag", + "//util/topsql/stmtstats", + "//util/tracing", + "@com_github_pingcap_errors//:errors", + "@com_github_tikv_client_go_v2//tikvrpc", + "@com_github_tikv_client_go_v2//util", + "@org_golang_x_exp//maps", + "@org_golang_x_exp//slices", + "@org_uber_go_atomic//:atomic", + "@org_uber_go_zap//:zap", + ], +) + +go_test( + name = "stmtctx_test", + timeout = "short", + srcs = [ + "main_test.go", + "stmtctx_test.go", + ], + embed = [":stmtctx"], + flaky = True, + shard_count = 6, + deps = [ + "//kv", + "//sessionctx/variable", + "//testkit", + "//testkit/testsetup", + "//util/execdetails", + "@com_github_pingcap_errors//:errors", + "@com_github_stretchr_testify//require", + "@com_github_tikv_client_go_v2//util", + "@org_uber_go_atomic//:atomic", + "@org_uber_go_goleak//:goleak", + ], +) diff --git a/sessionctx/stmtctx/stmtctx.go b/sessionctx/stmtctx/stmtctx.go index d5f19e8fe6dd5..09dc755d0deb1 100644 --- a/sessionctx/stmtctx/stmtctx.go +++ b/sessionctx/stmtctx/stmtctx.go @@ -300,7 +300,6 @@ type StatementContext struct { type StmtHints struct { // Hint Information MemQuotaQuery int64 - ApplyCacheCapacity int64 MaxExecutionTime uint64 ReplicaRead byte AllowInSubqToJoinAndAgg bool @@ -329,6 +328,45 @@ func (sh *StmtHints) TaskMapNeedBackUp() bool { return sh.ForceNthPlan != -1 } +// Clone the 
StmtHints struct and return a pointer to the new one.
+func (sh *StmtHints) Clone() *StmtHints {
+	var (
+		vars       map[string]string
+		tableHints []*ast.TableOptimizerHint
+	)
+	// Deep-copy the map and the slice so that the clone cannot be mutated
+	// through the original (and vice versa).
+	if len(sh.SetVars) > 0 {
+		vars = make(map[string]string, len(sh.SetVars))
+		for k, v := range sh.SetVars {
+			vars[k] = v
+		}
+	}
+	if len(sh.OriginalTableHints) > 0 {
+		tableHints = make([]*ast.TableOptimizerHint, len(sh.OriginalTableHints))
+		copy(tableHints, sh.OriginalTableHints)
+	}
+	return &StmtHints{
+		MemQuotaQuery:                  sh.MemQuotaQuery,
+		MaxExecutionTime:               sh.MaxExecutionTime,
+		TidbKvReadTimeout:              sh.TidbKvReadTimeout,
+		ReplicaRead:                    sh.ReplicaRead,
+		AllowInSubqToJoinAndAgg:        sh.AllowInSubqToJoinAndAgg,
+		NoIndexMergeHint:               sh.NoIndexMergeHint,
+		StraightJoinOrder:              sh.StraightJoinOrder,
+		EnableCascadesPlanner:          sh.EnableCascadesPlanner,
+		ForceNthPlan:                   sh.ForceNthPlan,
+		ResourceGroup:                  sh.ResourceGroup,
+		HasAllowInSubqToJoinAndAggHint: sh.HasAllowInSubqToJoinAndAggHint,
+		HasMemQuotaHint:                sh.HasMemQuotaHint,
+		HasReplicaReadHint:             sh.HasReplicaReadHint,
+		HasMaxExecutionTime:            sh.HasMaxExecutionTime,
+		HasTidbKvReadTimeout:           sh.HasTidbKvReadTimeout,
+		HasEnableCascadesPlannerHint:   sh.HasEnableCascadesPlannerHint,
+		HasResourceGroup:               sh.HasResourceGroup,
+		SetVars:                        vars,
+		OriginalTableHints:             tableHints,
+	}
+}
+
 // StmtCacheKey represents the key type in the StmtCache.
 type StmtCacheKey int
diff --git a/sessionctx/stmtctx/stmtctx_test.go b/sessionctx/stmtctx/stmtctx_test.go
index 7a4ec77a90660..9151dff64320e 100644
--- a/sessionctx/stmtctx/stmtctx_test.go
+++ b/sessionctx/stmtctx/stmtctx_test.go
@@ -17,6 +17,9 @@ package stmtctx_test
 import (
 	"context"
 	"fmt"
+	"math/rand"
+	"reflect"
+	"sort"
 	"testing"
 	"time"
@@ -143,3 +146,148 @@ func TestWeakConsistencyRead(t *testing.T) {
 	execAndCheck("execute s", testkit.Rows("1 1 2"), kv.SI)
 	tk.MustExec("rollback")
 }
+
+func TestMarshalSQLWarn(t *testing.T) {
+	warns := []stmtctx.SQLWarn{
+		{
+			Level: stmtctx.WarnLevelError,
+			Err:   errors.New("any error"),
+		},
+		{
+			Level: stmtctx.WarnLevelError,
+			Err:   errors.Trace(errors.New("any error")),
+		},
+		{
+			Level: stmtctx.WarnLevelWarning,
+			Err:   variable.ErrUnknownSystemVar.GenWithStackByArgs("unknown"),
+		},
+		{
+			Level: stmtctx.WarnLevelWarning,
+			Err:   errors.Trace(variable.ErrUnknownSystemVar.GenWithStackByArgs("unknown")),
+		},
+	}
+
+	store := testkit.CreateMockStore(t)
+	tk := testkit.NewTestKit(t, store)
+	// First query can trigger loading global variables, which produces warnings.
+	tk.MustQuery("select 1")
+	tk.Session().GetSessionVars().StmtCtx.SetWarnings(warns)
+	rows := tk.MustQuery("show warnings").Rows()
+	require.Equal(t, len(warns), len(rows))
+
+	// The unmarshalled result doesn't need to be exactly the same as the original one.
+	// We only require that the results of `show warnings` are the same.
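+	// That is, the property under test is only (illustrative):
+	//
+	//   render(unmarshal(marshal(warns))) == render(warns)
+	//
+	// where render stands for whatever `show warnings` displays, not for
+	// Go-level equality of the SQLWarn values.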
+	bytes, err := json.Marshal(warns)
+	require.NoError(t, err)
+	var newWarns []stmtctx.SQLWarn
+	err = json.Unmarshal(bytes, &newWarns)
+	require.NoError(t, err)
+	tk.Session().GetSessionVars().StmtCtx.SetWarnings(newWarns)
+	tk.MustQuery("show warnings").Check(rows)
+}
+
+func TestApproxRuntimeInfo(t *testing.T) {
+	var n = rand.Intn(19000) + 1000
+	var valRange = rand.Int31n(10000) + 1000
+	backoffs := []string{"tikvRPC", "pdRPC", "regionMiss"}
+	details := []*execdetails.ExecDetails{}
+	for i := 0; i < n; i++ {
+		d := &execdetails.ExecDetails{
+			DetailsNeedP90: execdetails.DetailsNeedP90{
+				CalleeAddress: fmt.Sprintf("%v", i+1),
+				BackoffSleep:  make(map[string]time.Duration),
+				BackoffTimes:  make(map[string]int),
+				TimeDetail: util.TimeDetail{
+					ProcessTime: time.Second * time.Duration(rand.Int31n(valRange)),
+					WaitTime:    time.Millisecond * time.Duration(rand.Int31n(valRange)),
+				},
+			},
+		}
+		details = append(details, d)
+		for _, backoff := range backoffs {
+			d.BackoffSleep[backoff] = time.Millisecond * 100 * time.Duration(rand.Int31n(valRange))
+			d.BackoffTimes[backoff] = rand.Intn(int(valRange))
+		}
+	}
+
+	// Make the CalleeAddress of each maximum value deterministic.
+	details[rand.Intn(n)].DetailsNeedP90.TimeDetail.ProcessTime = time.Second * time.Duration(valRange)
+	details[rand.Intn(n)].DetailsNeedP90.TimeDetail.WaitTime = time.Millisecond * time.Duration(valRange)
+	for _, backoff := range backoffs {
+		details[rand.Intn(n)].BackoffSleep[backoff] = time.Millisecond * 100 * time.Duration(valRange)
+	}
+
+	ctx := new(stmtctx.StatementContext)
+	for i := 0; i < n; i++ {
+		ctx.MergeExecDetails(details[i], nil)
+	}
+	d := ctx.CopTasksDetails()
+
+	require.Equal(t, d.NumCopTasks, n)
+	sort.Slice(details, func(i, j int) bool {
+		return details[i].TimeDetail.ProcessTime.Nanoseconds() < details[j].TimeDetail.ProcessTime.Nanoseconds()
+	})
+	var timeSum time.Duration
+	for _, detail := range details {
+		timeSum += detail.TimeDetail.ProcessTime
+	}
+	require.Equal(t, d.AvgProcessTime, timeSum/time.Duration(n))
+	require.InEpsilon(t, d.P90ProcessTime.Nanoseconds(), details[n*9/10].TimeDetail.ProcessTime.Nanoseconds(), 0.05)
+	require.Equal(t, d.MaxProcessTime, details[n-1].TimeDetail.ProcessTime)
+	require.Equal(t, d.MaxProcessAddress, details[n-1].CalleeAddress)
+
+	sort.Slice(details, func(i, j int) bool {
+		return details[i].TimeDetail.WaitTime.Nanoseconds() < details[j].TimeDetail.WaitTime.Nanoseconds()
+	})
+	timeSum = 0
+	for _, detail := range details {
+		timeSum += detail.TimeDetail.WaitTime
+	}
+	require.Equal(t, d.AvgWaitTime, timeSum/time.Duration(n))
+	require.InEpsilon(t, d.P90WaitTime.Nanoseconds(), details[n*9/10].TimeDetail.WaitTime.Nanoseconds(), 0.05)
+	require.Equal(t, d.MaxWaitTime, details[n-1].TimeDetail.WaitTime)
+	require.Equal(t, d.MaxWaitAddress, details[n-1].CalleeAddress)
+
+	fields := d.ToZapFields()
+	require.Equal(t, 9, len(fields))
+	for _, backoff := range backoffs {
+		sort.Slice(details, func(i, j int) bool {
+			return details[i].BackoffSleep[backoff].Nanoseconds() < details[j].BackoffSleep[backoff].Nanoseconds()
+		})
+		timeSum = 0
+		var timesSum = 0
+		for _, detail := range details {
+			timeSum += detail.BackoffSleep[backoff]
+			timesSum += detail.BackoffTimes[backoff]
+		}
+		require.Equal(t, d.MaxBackoffAddress[backoff], details[n-1].CalleeAddress)
+		require.Equal(t, d.MaxBackoffTime[backoff], details[n-1].BackoffSleep[backoff])
+		require.InEpsilon(t, d.P90BackoffTime[backoff], details[n*9/10].BackoffSleep[backoff], 0.1)
+		require.Equal(t, d.AvgBackoffTime[backoff], timeSum/time.Duration(n))
+
+		require.Equal(t, d.TotBackoffTimes[backoff], timesSum)
+		require.Equal(t, d.TotBackoffTime[backoff], timeSum)
+	}
+}
+
+func TestStmtHintsClone(t *testing.T) {
+	hints := stmtctx.StmtHints{}
+	value := reflect.ValueOf(&hints).Elem()
+	for i := 0; i < value.NumField(); i++ {
+		field := value.Field(i)
+		switch field.Kind() {
+		case reflect.Int, reflect.Int32, reflect.Int64:
+			field.SetInt(1)
+		case reflect.Uint, reflect.Uint32, reflect.Uint64:
+			field.SetUint(1)
+		case reflect.Uint8: // byte
+			field.SetUint(1)
+		case reflect.Bool:
+			field.SetBool(true)
+		case reflect.String:
+			field.SetString("test")
+		default:
+		}
+	}
+	require.Equal(t, hints, *hints.Clone())
+}