diff --git a/executor/calibrate_resource.go b/executor/calibrate_resource.go index bba3a714f599e..0ca6f5cba29f9 100644 --- a/executor/calibrate_resource.go +++ b/executor/calibrate_resource.go @@ -132,19 +132,19 @@ type calibrateResourceExec struct { done bool } -func (e *calibrateResourceExec) parseCalibrateDuration() (startTime time.Time, endTime time.Time, err error) { +func (e *calibrateResourceExec) parseCalibrateDuration(ctx context.Context) (startTime time.Time, endTime time.Time, err error) { var dur time.Duration var ts uint64 for _, op := range e.optionList { switch op.Tp { case ast.CalibrateStartTime: - ts, err = staleread.CalculateAsOfTsExpr(e.ctx, op.Ts) + ts, err = staleread.CalculateAsOfTsExpr(ctx, e.ctx, op.Ts) if err != nil { return } startTime = oracle.GetTimeFromTS(ts) case ast.CalibrateEndTime: - ts, err = staleread.CalculateAsOfTsExpr(e.ctx, op.Ts) + ts, err = staleread.CalculateAsOfTsExpr(ctx, e.ctx, op.Ts) if err != nil { return } @@ -197,7 +197,7 @@ func (e *calibrateResourceExec) Next(ctx context.Context, req *chunk.Chunk) erro } func (e *calibrateResourceExec) dynamicCalibrate(ctx context.Context, req *chunk.Chunk, exec sqlexec.RestrictedSQLExecutor) error { - startTs, endTs, err := e.parseCalibrateDuration() + startTs, endTs, err := e.parseCalibrateDuration(ctx) if err != nil { return err } diff --git a/executor/ddl.go b/executor/ddl.go index 97bf58c0a3c85..7dc2ab47de971 100644 --- a/executor/ddl.go +++ b/executor/ddl.go @@ -538,7 +538,7 @@ func (e *DDLExec) getRecoverTableByTableName(tableName *ast.TableName) (*model.J } func (e *DDLExec) executeFlashBackCluster(s *ast.FlashBackToTimestampStmt) error { - flashbackTS, err := staleread.CalculateAsOfTsExpr(e.ctx, s.FlashbackTS) + flashbackTS, err := staleread.CalculateAsOfTsExpr(context.Background(), e.ctx, s.FlashbackTS) if err != nil { return err } diff --git a/executor/stale_txn_test.go b/executor/stale_txn_test.go index dc1ab4ff962a8..e621c33ccc675 100644 --- a/executor/stale_txn_test.go +++ b/executor/stale_txn_test.go @@ -1396,3 +1396,28 @@ func TestStalePrepare(t *testing.T) { tk.MustQuery("execute stmt").Check(expected) } } + +func TestStaleTSO(t *testing.T) { + store := testkit.CreateMockStore(t) + tk := testkit.NewTestKit(t, store) + tk.MustExec("use test") + tk.MustExec("drop table if exists t") + defer tk.MustExec("drop table if exists t") + tk.MustExec("create table t (id int)") + + tk.MustExec("insert into t values(1)") + + asOfExprs := []string{ + "now(3) - interval 1 second", + "current_time() - interval 1 second", + "curtime() - interval 1 second", + } + + nextTSO := oracle.GoTimeToTS(time.Now().Add(2 * time.Second)) + require.Nil(t, failpoint.Enable("github.com/pingcap/tidb/sessiontxn/staleread/mockStaleReadTSO", fmt.Sprintf("return(%d)", nextTSO))) + defer failpoint.Disable("github.com/pingcap/tidb/sessiontxn/staleread/mockStaleReadTSO") + for _, expr := range asOfExprs { + // Make sure the now() expr is evaluated from the stale ts provider. + tk.MustQuery("select * from t as of timestamp " + expr + " order by id asc").Check(testkit.Rows("1")) + } +} diff --git a/expression/helper.go b/expression/helper.go index f286ee644209a..c96dabe556300 100644 --- a/expression/helper.go +++ b/expression/helper.go @@ -27,6 +27,7 @@ import ( "github.com/pingcap/tidb/sessionctx" "github.com/pingcap/tidb/types" driver "github.com/pingcap/tidb/types/parser_driver" + "github.com/tikv/client-go/v2/oracle" ) func boolToInt64(v bool) int64 { @@ -158,6 +159,13 @@ func getStmtTimestamp(ctx sessionctx.Context) (time.Time, error) { failpoint.Return(v, nil) }) + if ctx != nil { + staleTSO, err := ctx.GetSessionVars().StmtCtx.GetStaleTSO() + if staleTSO != 0 && err == nil { + return oracle.GetTimeFromTS(staleTSO), nil + } + } + now := time.Now() if ctx == nil { diff --git a/planner/core/planbuilder.go b/planner/core/planbuilder.go index 3679e7fff7240..06f85d5573ec3 100644 --- a/planner/core/planbuilder.go +++ b/planner/core/planbuilder.go @@ -3549,7 +3549,7 @@ func (b *PlanBuilder) buildSimple(ctx context.Context, node ast.StmtNode) (Plan, case *ast.BeginStmt: readTS := b.ctx.GetSessionVars().TxnReadTS.PeakTxnReadTS() if raw.AsOf != nil { - startTS, err := staleread.CalculateAsOfTsExpr(b.ctx, raw.AsOf.TsExpr) + startTS, err := staleread.CalculateAsOfTsExpr(ctx, b.ctx, raw.AsOf.TsExpr) if err != nil { return nil, err } diff --git a/sessionctx/context.go b/sessionctx/context.go index b8b10511aed72..65eda41aa058b 100644 --- a/sessionctx/context.go +++ b/sessionctx/context.go @@ -244,7 +244,10 @@ const allowedTimeFromNow = 100 * time.Millisecond // ValidateStaleReadTS validates that readTS does not exceed the current time not strictly. func ValidateStaleReadTS(ctx context.Context, sctx Context, readTS uint64) error { - currentTS, err := sctx.GetStore().GetOracle().GetStaleTimestamp(ctx, oracle.GlobalTxnScope, 0) + currentTS, err := sctx.GetSessionVars().StmtCtx.GetStaleTSO() + if currentTS == 0 || err != nil { + currentTS, err = sctx.GetStore().GetOracle().GetStaleTimestamp(ctx, oracle.GlobalTxnScope, 0) + } // If we fail to calculate currentTS from local time, fallback to get a timestamp from PD if err != nil { metrics.ValidateReadTSFromPDCount.Inc() diff --git a/sessionctx/stmtctx/stmtctx.go b/sessionctx/stmtctx/stmtctx.go index 26e48a3f70fe7..27359d5b0ccb2 100644 --- a/sessionctx/stmtctx/stmtctx.go +++ b/sessionctx/stmtctx/stmtctx.go @@ -405,6 +405,12 @@ type StatementContext struct { TiFlashEngineRemovedDueToStrictSQLMode bool // CanonicalHashCode try to get the canonical hash code from expression. CanonicalHashCode bool + // StaleTSOProvider is used to provide stale timestamp oracle for read-only transactions. + StaleTSOProvider struct { + sync.Mutex + value *uint64 + eval func() (uint64, error) + } } // StmtHints are SessionVars related sql hints. @@ -1229,6 +1235,32 @@ func (sc *StatementContext) DetachMemDiskTracker() { } } +// SetStaleTSOProvider sets the stale TSO provider. +func (sc *StatementContext) SetStaleTSOProvider(eval func() (uint64, error)) { + sc.StaleTSOProvider.Lock() + defer sc.StaleTSOProvider.Unlock() + sc.StaleTSOProvider.value = nil + sc.StaleTSOProvider.eval = eval +} + +// GetStaleTSO returns the TSO for stale-read usage which calculate from PD's last response. +func (sc *StatementContext) GetStaleTSO() (uint64, error) { + sc.StaleTSOProvider.Lock() + defer sc.StaleTSOProvider.Unlock() + if sc.StaleTSOProvider.value != nil { + return *sc.StaleTSOProvider.value, nil + } + if sc.StaleTSOProvider.eval == nil { + return 0, nil + } + tso, err := sc.StaleTSOProvider.eval() + if err != nil { + return 0, err + } + sc.StaleTSOProvider.value = &tso + return tso, nil +} + // CopTasksDetails collects some useful information of cop-tasks during execution. type CopTasksDetails struct { NumCopTasks int diff --git a/sessiontxn/staleread/processor.go b/sessiontxn/staleread/processor.go index 17df59c2873e3..62db35c74ace6 100644 --- a/sessiontxn/staleread/processor.go +++ b/sessiontxn/staleread/processor.go @@ -16,7 +16,6 @@ package staleread import ( "context" - "github.com/pingcap/errors" "github.com/pingcap/tidb/domain" "github.com/pingcap/tidb/infoschema" @@ -280,7 +279,7 @@ func parseAndValidateAsOf(ctx context.Context, sctx sessionctx.Context, asOf *as return 0, nil } - ts, err := CalculateAsOfTsExpr(sctx, asOf.TsExpr) + ts, err := CalculateAsOfTsExpr(ctx, sctx, asOf.TsExpr) if err != nil { return 0, err } diff --git a/sessiontxn/staleread/util.go b/sessiontxn/staleread/util.go index 3fa84f72cae0b..220759e770452 100644 --- a/sessiontxn/staleread/util.go +++ b/sessiontxn/staleread/util.go @@ -18,6 +18,7 @@ import ( "context" "time" + "github.com/pingcap/failpoint" "github.com/pingcap/tidb/expression" "github.com/pingcap/tidb/parser/ast" "github.com/pingcap/tidb/parser/mysql" @@ -29,7 +30,14 @@ import ( ) // CalculateAsOfTsExpr calculates the TsExpr of AsOfClause to get a StartTS. -func CalculateAsOfTsExpr(sctx sessionctx.Context, tsExpr ast.ExprNode) (uint64, error) { +func CalculateAsOfTsExpr(ctx context.Context, sctx sessionctx.Context, tsExpr ast.ExprNode) (uint64, error) { + sctx.GetSessionVars().StmtCtx.SetStaleTSOProvider(func() (uint64, error) { + failpoint.Inject("mockStaleReadTSO", func(val failpoint.Value) (uint64, error) { + return uint64(val.(int)), nil + }) + // this function accepts a context, but we don't need it when there is a valid cached ts. + return sctx.GetStore().GetOracle().GetStaleTimestamp(ctx, oracle.GlobalTxnScope, 0) + }) tsVal, err := expression.EvalAstExpr(sctx, tsExpr) if err != nil { return 0, err