Skip to content

Commit

Permalink
session,executor: tiny clean up the runStmt function (pingcap#17911)
Browse files Browse the repository at this point in the history
Co-authored-by: pingcap-github-bot <sre-bot@pingcap.com>
  • Loading branch information
tiancaiamao and sre-bot authored Jun 11, 2020
1 parent b4836c1 commit f139821
Show file tree
Hide file tree
Showing 3 changed files with 12 additions and 63 deletions.
6 changes: 4 additions & 2 deletions executor/seqtest/prepared_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -107,8 +107,9 @@ func (s *seqTestSuite) TestPrepared(c *C) {
query = "select c1 from prepare_test where c1 = (select c1 from prepare_test where c1 = ?)"
stmtID, _, _, err = tk.Se.PrepareStmt(query)
c.Assert(err, IsNil)
_, err = tk.Se.ExecutePreparedStmt(ctx, stmtID, []types.Datum{types.NewDatum(3)})
rs, err = tk.Se.ExecutePreparedStmt(ctx, stmtID, []types.Datum{types.NewDatum(3)})
c.Assert(err, IsNil)
c.Assert(rs.Close(), IsNil)
tk1.MustExec("insert prepare_test (c1) values (3)")
rs, err = tk.Se.ExecutePreparedStmt(ctx, stmtID, []types.Datum{types.NewDatum(3)})
c.Assert(err, IsNil)
Expand All @@ -118,8 +119,9 @@ func (s *seqTestSuite) TestPrepared(c *C) {
query = "select c1 from prepare_test where c1 in (select c1 from prepare_test where c1 = ?)"
stmtID, _, _, err = tk.Se.PrepareStmt(query)
c.Assert(err, IsNil)
_, err = tk.Se.ExecutePreparedStmt(ctx, stmtID, []types.Datum{types.NewDatum(3)})
rs, err = tk.Se.ExecutePreparedStmt(ctx, stmtID, []types.Datum{types.NewDatum(3)})
c.Assert(err, IsNil)
c.Assert(rs.Close(), IsNil)
tk1.MustExec("insert prepare_test (c1) values (3)")
rs, err = tk.Se.ExecutePreparedStmt(ctx, stmtID, []types.Datum{types.NewDatum(3)})
c.Assert(err, IsNil)
Expand Down
16 changes: 8 additions & 8 deletions session/session.go
Original file line number Diff line number Diff line change
Expand Up @@ -1161,7 +1161,7 @@ func (s *session) ExecuteStmt(ctx context.Context, stmtNode ast.StmtNode) (sqlex

// Execute the physical plan.
logStmt(stmtNode, s.sessionVars)
recordSet, err := runStmtWrap(ctx, s, stmt)
recordSet, err := runStmt(ctx, s, stmt)
if err != nil {
if !kv.ErrKeyExists.Equal(err) {
logutil.Logger(ctx).Warn("run statement failed",
Expand All @@ -1174,8 +1174,8 @@ func (s *session) ExecuteStmt(ctx context.Context, stmtNode ast.StmtNode) (sqlex
return recordSet, nil
}

// runStmtWrap executes the sqlexec.Statement and commit or rollback the current transaction.
func runStmtWrap(ctx context.Context, se *session, s sqlexec.Statement) (rs sqlexec.RecordSet, err error) {
// runStmt executes the sqlexec.Statement and commit or rollback the current transaction.
func runStmt(ctx context.Context, se *session, s sqlexec.Statement) (rs sqlexec.RecordSet, err error) {
if span := opentracing.SpanFromContext(ctx); span != nil && span.Tracer() != nil {
span1 := span.Tracer().StartSpan("session.runStmt", opentracing.ChildOf(span.Context()))
span1.LogKV("sql", s.OriginText())
Expand Down Expand Up @@ -1281,7 +1281,7 @@ func (s *session) PrepareStmt(sql string) (stmtID uint32, paramCount int, fields
return prepareExec.ID, prepareExec.ParamCount, prepareExec.Fields, nil
}

func (s *session) CommonExec(ctx context.Context,
func (s *session) preparedStmtExec(ctx context.Context,
stmtID uint32, prepareStmt *plannercore.CachedPrepareStmt, args []types.Datum) (sqlexec.RecordSet, error) {
st, err := executor.CompileExecutePreparedStmt(ctx, s, stmtID, args)
if err != nil {
Expand All @@ -1292,8 +1292,8 @@ func (s *session) CommonExec(ctx context.Context,
return runStmt(ctx, s, st)
}

// CachedPlanExec short path currently ONLY for cached "point select plan" execution
func (s *session) CachedPlanExec(ctx context.Context,
// cachedPlanExec short path currently ONLY for cached "point select plan" execution
func (s *session) cachedPlanExec(ctx context.Context,
stmtID uint32, prepareStmt *plannercore.CachedPrepareStmt, args []types.Datum) (sqlexec.RecordSet, error) {
prepared := prepareStmt.PreparedAst
// compile ExecStmt
Expand Down Expand Up @@ -1404,9 +1404,9 @@ func (s *session) ExecutePreparedStmt(ctx context.Context, stmtID uint32, args [
return nil, err
}
if ok {
return s.CachedPlanExec(ctx, stmtID, preparedStmt, args)
return s.cachedPlanExec(ctx, stmtID, preparedStmt, args)
}
return s.CommonExec(ctx, stmtID, preparedStmt, args)
return s.preparedStmtExec(ctx, stmtID, preparedStmt, args)
}

func (s *session) DropPreparedStmt(stmtID uint32) error {
Expand Down
53 changes: 0 additions & 53 deletions session/tidb.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@ import (
"sync/atomic"
"time"

"github.com/opentracing/opentracing-go"
"github.com/pingcap/errors"
"github.com/pingcap/parser"
"github.com/pingcap/parser/ast"
Expand Down Expand Up @@ -260,58 +259,6 @@ func checkStmtLimit(ctx context.Context, se *session) error {
return err
}

// runStmt executes the sqlexec.Statement and commit or rollback the current transaction.
func runStmt(ctx context.Context, sctx sessionctx.Context, s sqlexec.Statement) (rs sqlexec.RecordSet, err error) {
if span := opentracing.SpanFromContext(ctx); span != nil && span.Tracer() != nil {
span1 := span.Tracer().StartSpan("session.runStmt", opentracing.ChildOf(span.Context()))
span1.LogKV("sql", s.OriginText())
defer span1.Finish()
ctx = opentracing.ContextWithSpan(ctx, span1)
}
sctx.SetValue(sessionctx.QueryString, s.OriginText())
if _, ok := s.(*executor.ExecStmt).StmtNode.(ast.DDLNode); ok {
sctx.SetValue(sessionctx.LastExecuteDDL, true)
} else {
sctx.ClearValue(sessionctx.LastExecuteDDL)
}

se := sctx.(*session)
sessVars := se.GetSessionVars()
// Save origTxnCtx here to avoid it reset in the transaction retry.
origTxnCtx := sessVars.TxnCtx
defer func() {
// If it is not a select statement, we record its slow log here,
// then it could include the transaction commit time.
if rs == nil {
s.(*executor.ExecStmt).FinishExecuteStmt(origTxnCtx.StartTS, err == nil, false)
}
}()

err = se.checkTxnAborted(s)
if err != nil {
return nil, err
}
rs, err = s.Exec(ctx)
sessVars.TxnCtx.StatementCount++
if !s.IsReadOnly(sessVars) {
// All the history should be added here.
if err == nil && sessVars.TxnCtx.CouldRetry {
GetHistory(sctx).Add(s, sessVars.StmtCtx)
}

// Handle the stmt commit/rollback.
if se.txn.Valid() {
if err != nil {
sctx.StmtRollback()
} else {
err = sctx.StmtCommit(sctx.GetSessionVars().StmtCtx.MemTracker)
}
}
}
err = finishStmt(ctx, se, err, s)
return rs, err
}

// GetHistory get all stmtHistory in current txn. Exported only for test.
func GetHistory(ctx sessionctx.Context) *StmtHistory {
hist, ok := ctx.GetSessionVars().TxnCtx.History.(*StmtHistory)
Expand Down

0 comments on commit f139821

Please sign in to comment.