diff --git a/session/session.go b/session/session.go index 62933ff417eb8..6a2e3817e1aa5 100644 --- a/session/session.go +++ b/session/session.go @@ -1033,7 +1033,9 @@ func (s *session) Txn(active bool) kv.Transaction { // If Txn() is called later, wait for the future to get a valid txn. txnCap := s.getMembufCap() if err := s.txn.changePendingToValid(txnCap); err != nil { + log.Error("active transaction fail, err = ", err) s.txn.fail = errors.Trace(err) + s.txn.cleanup() } else { s.sessionVars.TxnCtx.StartTS = s.txn.StartTS() } diff --git a/session/session_fail_test.go b/session/session_fail_test.go index 9cbf408eeb500..2f53e010b98c7 100644 --- a/session/session_fail_test.go +++ b/session/session_fail_test.go @@ -14,6 +14,8 @@ package session_test import ( + "context" + gofail "github.com/etcd-io/gofail/runtime" . "github.com/pingcap/check" "github.com/pingcap/tidb/util/testkit" @@ -27,8 +29,22 @@ func (s *testSessionSuite) TestFailStatementCommit(c *C) { tk.MustExec("begin") tk.MustExec("insert into t values (1)") gofail.Enable("github.com/pingcap/tidb/session/mockStmtCommitError", `return(true)`) - tk.MustExec("insert into t values (2)") - _, err := tk.Exec("commit") + _, err := tk.Exec("insert into t values (2)") c.Assert(err, NotNil) - tk.MustQuery(`select * from t`).Check(testkit.Rows()) + tk.MustExec("commit") + tk.MustQuery(`select * from t`).Check(testkit.Rows("1")) +} + +func (s *testSessionSuite) TestGetTSFailDirtyState(c *C) { + tk := testkit.NewTestKitWithInit(c, s.store) + tk.MustExec("create table t (id int)") + + ctx := context.Background() + ctx = context.WithValue(ctx, "mockGetTSFail", struct{}{}) + tk.Se.Execute(ctx, "select * from t") + + // Fix a bug that active txn fail set TxnState.fail to error, and then the following write + // affected by this fail flag. + tk.MustExec("insert into t values (1)") + tk.MustQuery(`select * from t`).Check(testkit.Rows("1")) } diff --git a/session/tidb.go b/session/tidb.go index f94ca2ea76227..ba5b3d54bf23f 100644 --- a/session/tidb.go +++ b/session/tidb.go @@ -166,6 +166,18 @@ func runStmt(ctx context.Context, sctx sessionctx.Context, s sqlexec.Statement) } } } + + // There are two known cases that the s.txn.fail is not nil: + // 1. active transaction fail, can't get start ts for example + // 2. transaction too large and StmtCommit fail + // On both cases, we can return error in advance. + if se.txn.fail != nil { + err = se.txn.fail + se.txn.cleanup() + se.txn.fail = nil + return nil, errors.Trace(err) + } + if !sessVars.InTxn() { if err != nil { log.Info("RollbackTxn for ddl/autocommit error.") diff --git a/session/txn.go b/session/txn.go index 63b0cd7412ab1..2d9d89486dbe0 100644 --- a/session/txn.go +++ b/session/txn.go @@ -16,6 +16,7 @@ package session import ( "context" "fmt" + "runtime/debug" "strings" "github.com/opentracing/opentracing-go" @@ -157,6 +158,16 @@ func (st *TxnState) Commit(ctx context.Context) error { st.fail = nil return errors.Trace(err) } + if len(st.mutations) != 0 || len(st.dirtyTableOP) != 0 || st.buf.Len() != 0 { + log.Errorf("The code should never run here, TxnState=%#v, mutations=%#v, dirtyTableOP=%#v, buf=%#v something must be wrong: %s", + st, + st.mutations, + st.dirtyTableOP, + st.buf, + debug.Stack()) + st.cleanup() + return errors.New("invalid transaction") + } return errors.Trace(st.Transaction.Commit(ctx)) } @@ -273,9 +284,15 @@ func mergeToDirtyDB(dirtyDB *executor.DirtyDB, op dirtyTableOperation) { type txnFuture struct { future oracle.Future store kv.Storage + + mockFail bool } func (tf *txnFuture) wait() (kv.Transaction, error) { + if tf.mockFail { + return nil, errors.New("mock get timestamp fail") + } + startTS, err := tf.future.Wait() if err == nil { return tf.store.BeginWithStartTS(startTS) @@ -293,7 +310,11 @@ func (s *session) getTxnFuture(ctx context.Context) *txnFuture { oracleStore := s.store.GetOracle() tsFuture := oracleStore.GetTimestampAsync(ctx) - return &txnFuture{tsFuture, s.store} + ret := &txnFuture{future: tsFuture, store: s.store} + if x := ctx.Value("mockGetTSFail"); x != nil { + ret.mockFail = true + } + return ret } // StmtCommit implements the sessionctx.Context interface. diff --git a/table/tables/tables_test.go b/table/tables/tables_test.go index 3b1f2fc69e9ac..52232b67d90b2 100644 --- a/table/tables/tables_test.go +++ b/table/tables/tables_test.go @@ -289,6 +289,7 @@ func (ts *testSuite) TestUnsignedPK(c *C) { c.Assert(err, IsNil) c.Assert(len(row), Equals, 2) c.Assert(row[0].Kind(), Equals, types.KindUint64) + ts.se.StmtCommit() c.Assert(ts.se.Txn(true).Commit(context.Background()), IsNil) }