diff --git a/context/context.go b/context/context.go index ee263df465724..ce50221f44046 100644 --- a/context/context.go +++ b/context/context.go @@ -22,18 +22,17 @@ import ( // Context is an interface for transaction and executive args environment. type Context interface { - // GetTxn gets a transaction for further execution. - GetTxn(forceNew bool) (kv.Transaction, error) + // NewTxn creates a new transaction for further execution. + // If old transaction is valid, it is committed first. + // It's used in BEGIN statement and DDL statements to commit old transaction. + NewTxn() error + + // Txn returns the current transaction which is created before executing a statement. + Txn() kv.Transaction // GetClient gets a kv.Client. GetClient() kv.Client - // RollbackTxn rolls back the current transaction. - RollbackTxn() error - - // CommitTxn commits the current transaction. - CommitTxn() error - // SetValue saves a value associated with this context for key. SetValue(key fmt.Stringer, value interface{}) diff --git a/ddl/column_change_test.go b/ddl/column_change_test.go index 6560db0061e1b..d4973297d6e1f 100644 --- a/ddl/column_change_test.go +++ b/ddl/column_change_test.go @@ -58,7 +58,7 @@ func (s *testColumnChangeSuite) TestColumnChange(c *C) { // create table t (c1 int, c2 int); tblInfo := testTableInfo(c, d, "t", 2) ctx := testNewContext(c, d) - _, err := ctx.GetTxn(true) + err := ctx.NewTxn() c.Assert(err, IsNil) testCreateTable(c, ctx, d, s.dbInfo, tblInfo) // insert t values (1, 2); @@ -66,7 +66,7 @@ func (s *testColumnChangeSuite) TestColumnChange(c *C) { row := types.MakeDatums(1, 2) _, err = originTable.AddRecord(ctx, row) c.Assert(err, IsNil) - err = ctx.CommitTxn() + err = ctx.Txn().Commit() c.Assert(err, IsNil) var mu sync.Mutex @@ -87,6 +87,10 @@ func (s *testColumnChangeSuite) TestColumnChange(c *C) { hookCtx.Store = s.store prevState = job.SchemaState var err error + err = hookCtx.NewTxn() + if err != nil { + checkErr = errors.Trace(err) + } switch job.SchemaState { case model.StateDeleteOnly: deleteOnlyTable, err = getCurrentTable(d, s.dbInfo.ID, tblInfo.ID) @@ -114,6 +118,10 @@ func (s *testColumnChangeSuite) TestColumnChange(c *C) { } mu.Unlock() } + err = hookCtx.Txn().Commit() + if err != nil { + checkErr = errors.Trace(err) + } } d.setHook(tc) defaultValue := int64(3) @@ -142,6 +150,10 @@ func (s *testColumnChangeSuite) testAddColumnNoDefault(c *C, ctx context.Context hookCtx.Store = s.store prevState = job.SchemaState var err error + err = hookCtx.NewTxn() + if err != nil { + checkErr = errors.Trace(err) + } switch job.SchemaState { case model.StateWriteOnly: writeOnlyTable, err = getCurrentTable(d, s.dbInfo.ID, tblInfo.ID) @@ -158,6 +170,10 @@ func (s *testColumnChangeSuite) testAddColumnNoDefault(c *C, ctx context.Context checkErr = errors.Trace(err) } } + err = hookCtx.Txn().Commit() + if err != nil { + checkErr = errors.Trace(err) + } } d.setHook(tc) d.start() @@ -196,11 +212,15 @@ func (s *testColumnChangeSuite) testColumnDrop(c *C, ctx context.Context, d *ddl func (s *testColumnChangeSuite) checkAddWriteOnly(d *ddl, ctx context.Context, deleteOnlyTable, writeOnlyTable table.Table) error { // WriteOnlyTable: insert t values (2, 3) - _, err := writeOnlyTable.AddRecord(ctx, types.MakeDatums(2, 3)) + err := ctx.NewTxn() if err != nil { return errors.Trace(err) } - err = ctx.CommitTxn() + _, err = writeOnlyTable.AddRecord(ctx, types.MakeDatums(2, 3)) + if err != nil { + return errors.Trace(err) + } + err = ctx.NewTxn() if err != nil { return errors.Trace(err) } @@ -222,7 +242,7 @@ func (s *testColumnChangeSuite) checkAddWriteOnly(d *ddl, ctx context.Context, d if err != nil { return errors.Trace(err) } - err = ctx.CommitTxn() + err = ctx.NewTxn() if err != nil { return errors.Trace(err) } @@ -236,7 +256,7 @@ func (s *testColumnChangeSuite) checkAddWriteOnly(d *ddl, ctx context.Context, d if err != nil { return errors.Trace(err) } - err = ctx.CommitTxn() + err = ctx.NewTxn() if err != nil { return errors.Trace(err) } @@ -255,11 +275,15 @@ func touchedMap(t table.Table) map[int]bool { func (s *testColumnChangeSuite) checkAddPublic(d *ddl, ctx context.Context, writeOnlyTable, publicTable table.Table) error { // publicTable Insert t values (4, 4, 4) + err := ctx.NewTxn() + if err != nil { + return errors.Trace(err) + } h, err := publicTable.AddRecord(ctx, types.MakeDatums(4, 4, 4)) if err != nil { return errors.Trace(err) } - err = ctx.CommitTxn() + err = ctx.NewTxn() if err != nil { return errors.Trace(err) } @@ -276,7 +300,7 @@ func (s *testColumnChangeSuite) checkAddPublic(d *ddl, ctx context.Context, writ if err != nil { return errors.Trace(err) } - err = ctx.CommitTxn() + err = ctx.NewTxn() if err != nil { return errors.Trace(err) } diff --git a/ddl/column_test.go b/ddl/column_test.go index e11333ddb6f62..a508a7e24c749 100644 --- a/ddl/column_test.go +++ b/ddl/column_test.go @@ -104,7 +104,6 @@ func (s *testColumnSuite) TestColumn(c *C) { defer testleak.AfterTest(c)() tblInfo := testTableInfo(c, s.d, "t1", 3) ctx := testNewContext(c, s.d) - defer ctx.RollbackTxn() testCreateTable(c, ctx, s.d, s.dbInfo, tblInfo) @@ -116,7 +115,7 @@ func (s *testColumnSuite) TestColumn(c *C) { c.Assert(err, IsNil) } - err := ctx.CommitTxn() + err := ctx.NewTxn() c.Assert(err, IsNil) i := int64(0) @@ -152,7 +151,7 @@ func (s *testColumnSuite) TestColumn(c *C) { h, err := t.AddRecord(ctx, types.MakeDatums(11, 12, 13, 14)) c.Assert(err, IsNil) - err = ctx.CommitTxn() + err = ctx.NewTxn() c.Assert(err, IsNil) values, err := t.RowWithCols(ctx, h, t.Cols()) c.Assert(err, IsNil) @@ -254,13 +253,13 @@ func (s *testColumnSuite) TestColumn(c *C) { } func (s *testColumnSuite) checkColumnKVExist(ctx context.Context, t table.Table, handle int64, col *table.Column, columnValue interface{}, isExist bool) error { - txn, err := ctx.GetTxn(true) + err := ctx.NewTxn() if err != nil { return errors.Trace(err) } - defer ctx.CommitTxn() + defer ctx.Txn().Commit() key := t.RecordKey(handle) - data, err := txn.Get(key) + data, err := ctx.Txn().Get(key) if !isExist { if terror.ErrorEqual(err, kv.ErrNotExist) { return nil @@ -303,7 +302,7 @@ func (s *testColumnSuite) checkNoneColumn(c *C, ctx context.Context, d *ddl, tbl func (s *testColumnSuite) checkDeleteOnlyColumn(c *C, ctx context.Context, d *ddl, tblInfo *model.TableInfo, handle int64, col *table.Column, row []types.Datum, columnValue interface{}) error { t := testGetTable(c, d, s.dbInfo.ID, tblInfo.ID) - _, err := ctx.GetTxn(true) + err := ctx.NewTxn() if err != nil { return errors.Trace(err) } @@ -326,7 +325,7 @@ func (s *testColumnSuite) checkDeleteOnlyColumn(c *C, ctx context.Context, d *dd return errors.Trace(err) } // Test add a new row. - _, err = ctx.GetTxn(true) + err = ctx.NewTxn() if err != nil { return errors.Trace(err) } @@ -336,8 +335,7 @@ func (s *testColumnSuite) checkDeleteOnlyColumn(c *C, ctx context.Context, d *dd if err != nil { return errors.Trace(err) } - - _, err = ctx.GetTxn(true) + err = ctx.NewTxn() if err != nil { return errors.Trace(err) } @@ -364,7 +362,7 @@ func (s *testColumnSuite) checkDeleteOnlyColumn(c *C, ctx context.Context, d *dd return errors.Trace(err) } // Test remove a row. - _, err = ctx.GetTxn(true) + err = ctx.NewTxn() if err != nil { return errors.Trace(err) } @@ -373,7 +371,7 @@ func (s *testColumnSuite) checkDeleteOnlyColumn(c *C, ctx context.Context, d *dd if err != nil { return errors.Trace(err) } - _, err = ctx.GetTxn(true) + err = ctx.NewTxn() if err != nil { return errors.Trace(err) } @@ -403,7 +401,7 @@ func (s *testColumnSuite) checkDeleteOnlyColumn(c *C, ctx context.Context, d *dd func (s *testColumnSuite) checkWriteOnlyColumn(c *C, ctx context.Context, d *ddl, tblInfo *model.TableInfo, handle int64, col *table.Column, row []types.Datum, columnValue interface{}) error { t := testGetTable(c, d, s.dbInfo.ID, tblInfo.ID) - _, err := ctx.GetTxn(true) + err := ctx.NewTxn() if err != nil { return errors.Trace(err) } @@ -429,7 +427,7 @@ func (s *testColumnSuite) checkWriteOnlyColumn(c *C, ctx context.Context, d *ddl } // Test add a new row. - _, err = ctx.GetTxn(true) + err = ctx.NewTxn() if err != nil { return errors.Trace(err) } @@ -439,8 +437,7 @@ func (s *testColumnSuite) checkWriteOnlyColumn(c *C, ctx context.Context, d *ddl if err != nil { return errors.Trace(err) } - - _, err = ctx.GetTxn(true) + err = ctx.NewTxn() if err != nil { return errors.Trace(err) } @@ -467,7 +464,7 @@ func (s *testColumnSuite) checkWriteOnlyColumn(c *C, ctx context.Context, d *ddl return errors.Trace(err) } // Test remove a row. - _, err = ctx.GetTxn(true) + err = ctx.NewTxn() if err != nil { return errors.Trace(err) } @@ -476,7 +473,7 @@ func (s *testColumnSuite) checkWriteOnlyColumn(c *C, ctx context.Context, d *ddl if err != nil { return errors.Trace(err) } - _, err = ctx.GetTxn(true) + err = ctx.NewTxn() if err != nil { return errors.Trace(err) } @@ -507,7 +504,7 @@ func (s *testColumnSuite) checkWriteOnlyColumn(c *C, ctx context.Context, d *ddl func (s *testColumnSuite) checkReorganizationColumn(c *C, ctx context.Context, d *ddl, tblInfo *model.TableInfo, handle int64, col *table.Column, row []types.Datum, columnValue interface{}) error { t := testGetTable(c, d, s.dbInfo.ID, tblInfo.ID) - _, err := ctx.GetTxn(true) + err := ctx.NewTxn() if err != nil { return errors.Trace(err) } @@ -528,7 +525,7 @@ func (s *testColumnSuite) checkReorganizationColumn(c *C, ctx context.Context, d } // Test add a new row. - _, err = ctx.GetTxn(true) + err = ctx.NewTxn() if err != nil { return errors.Trace(err) } @@ -538,7 +535,7 @@ func (s *testColumnSuite) checkReorganizationColumn(c *C, ctx context.Context, d if err != nil { return errors.Trace(err) } - _, err = ctx.GetTxn(true) + err = ctx.NewTxn() if err != nil { return errors.Trace(err) } @@ -566,7 +563,7 @@ func (s *testColumnSuite) checkReorganizationColumn(c *C, ctx context.Context, d } // Test remove a row. - _, err = ctx.GetTxn(true) + err = ctx.NewTxn() if err != nil { return errors.Trace(err) } @@ -575,8 +572,7 @@ func (s *testColumnSuite) checkReorganizationColumn(c *C, ctx context.Context, d if err != nil { return errors.Trace(err) } - - _, err = ctx.GetTxn(true) + err = ctx.NewTxn() if err != nil { return errors.Trace(err) } @@ -602,7 +598,7 @@ func (s *testColumnSuite) checkReorganizationColumn(c *C, ctx context.Context, d func (s *testColumnSuite) checkPublicColumn(c *C, ctx context.Context, d *ddl, tblInfo *model.TableInfo, handle int64, newCol *table.Column, oldRow []types.Datum, columnValue interface{}) error { t := testGetTable(c, d, s.dbInfo.ID, tblInfo.ID) - _, err := ctx.GetTxn(true) + err := ctx.NewTxn() if err != nil { return errors.Trace(err) } @@ -624,7 +620,7 @@ func (s *testColumnSuite) checkPublicColumn(c *C, ctx context.Context, d *ddl, t } // Test add a new row. - _, err = ctx.GetTxn(true) + err = ctx.NewTxn() if err != nil { return errors.Trace(err) } @@ -634,8 +630,7 @@ func (s *testColumnSuite) checkPublicColumn(c *C, ctx context.Context, d *ddl, t if err != nil { return errors.Trace(err) } - - _, err = ctx.GetTxn(true) + err = ctx.NewTxn() if err != nil { return errors.Trace(err) } @@ -655,7 +650,7 @@ func (s *testColumnSuite) checkPublicColumn(c *C, ctx context.Context, d *ddl, t } // Test remove a row. - _, err = ctx.GetTxn(true) + err = ctx.NewTxn() if err != nil { return errors.Trace(err) } @@ -665,7 +660,7 @@ func (s *testColumnSuite) checkPublicColumn(c *C, ctx context.Context, d *ddl, t return errors.Trace(err) } - _, err = ctx.GetTxn(true) + err = ctx.NewTxn() if err != nil { return errors.Trace(err) } @@ -683,10 +678,6 @@ func (s *testColumnSuite) checkPublicColumn(c *C, ctx context.Context, d *ddl, t return errors.Errorf("expect 1, got %v", i) } - err = ctx.CommitTxn() - if err != nil { - return errors.Trace(err) - } err = s.testGetColumn(t, newCol.Name.L, true) if err != nil { return errors.Trace(err) @@ -732,7 +723,7 @@ func (s *testColumnSuite) TestAddColumn(c *C) { tblInfo := testTableInfo(c, d, "t", 3) ctx := testNewContext(c, d) - _, err := ctx.GetTxn(true) + err := ctx.NewTxn() c.Assert(err, IsNil) testCreateTable(c, ctx, d, s.dbInfo, tblInfo) @@ -742,7 +733,7 @@ func (s *testColumnSuite) TestAddColumn(c *C) { handle, err := t.AddRecord(ctx, oldRow) c.Assert(err, IsNil) - err = ctx.CommitTxn() + err = ctx.Txn().Commit() c.Assert(err, IsNil) newColName := "c4" @@ -784,13 +775,13 @@ func (s *testColumnSuite) TestAddColumn(c *C) { c.Assert(errors.ErrorStack(checkErr), Equals, "") testCheckJobDone(c, d, job, true) - _, err = ctx.GetTxn(true) + err = ctx.NewTxn() c.Assert(err, IsNil) job = testDropTable(c, ctx, d, s.dbInfo, tblInfo) testCheckJobDone(c, d, job, false) - err = ctx.CommitTxn() + err = ctx.Txn().Commit() c.Assert(err, IsNil) d.close() @@ -803,7 +794,7 @@ func (s *testColumnSuite) TestDropColumn(c *C) { tblInfo := testTableInfo(c, d, "t", 4) ctx := testNewContext(c, d) - _, err := ctx.GetTxn(true) + err := ctx.NewTxn() c.Assert(err, IsNil) testCreateTable(c, ctx, d, s.dbInfo, tblInfo) @@ -816,7 +807,7 @@ func (s *testColumnSuite) TestDropColumn(c *C) { _, err = t.AddRecord(ctx, append(row, types.NewDatum(defaultColValue))) c.Assert(err, IsNil) - err = ctx.CommitTxn() + err = ctx.Txn().Commit() c.Assert(err, IsNil) checkOK := false @@ -851,13 +842,13 @@ func (s *testColumnSuite) TestDropColumn(c *C) { c.Assert(checkOK, IsTrue) mu.Unlock() - _, err = ctx.GetTxn(true) + err = ctx.NewTxn() c.Assert(err, IsNil) job = testDropTable(c, ctx, d, s.dbInfo, tblInfo) testCheckJobDone(c, d, job, false) - err = ctx.CommitTxn() + err = ctx.Txn().Commit() c.Assert(err, IsNil) d.close() diff --git a/ddl/ddl.go b/ddl/ddl.go index 0e5a76c8e2517..c93df16a8600e 100644 --- a/ddl/ddl.go +++ b/ddl/ddl.go @@ -1408,7 +1408,7 @@ func (d *ddl) DropIndex(ctx context.Context, ti ast.Ident, indexName model.CIStr func (d *ddl) doDDLJob(ctx context.Context, job *model.Job) error { // For every DDL, we must commit current transaction. - if err := ctx.CommitTxn(); err != nil { + if err := ctx.NewTxn(); err != nil { return errors.Trace(err) } diff --git a/ddl/ddl_db_test.go b/ddl/ddl_db_test.go index 76b6e5599f6c1..708bec8b860b1 100644 --- a/ddl/ddl_db_test.go +++ b/ddl/ddl_db_test.go @@ -348,6 +348,7 @@ LOOP: // get all row handles ctx := s.s.(context.Context) + c.Assert(ctx.NewTxn(), IsNil) t := s.testGetTable(c, "t1") handles := make(map[int64]struct{}) err := t.IterRecords(ctx, t.FirstKey(), t.Cols(), @@ -368,11 +369,10 @@ LOOP: // Make sure there is index with name c3_index. c.Assert(nidx, NotNil) c.Assert(nidx.Meta().ID, Greater, int64(0)) - txn, err := ctx.GetTxn(true) - c.Assert(err, IsNil) - defer ctx.RollbackTxn() + c.Assert(ctx.NewTxn(), IsNil) + defer ctx.Txn().Rollback() - it, err := nidx.SeekFirst(txn) + it, err := nidx.SeekFirst(ctx.Txn()) c.Assert(err, IsNil) defer it.Close() @@ -454,11 +454,10 @@ LOOP: // Make sure there is no index with name c3_index. c.Assert(nidx, IsNil) idx := tables.NewIndex(t.Meta(), c3idx.Meta()) - txn, err := ctx.GetTxn(true) - c.Assert(err, IsNil) - defer ctx.RollbackTxn() + c.Assert(ctx.NewTxn(), IsNil) + defer ctx.Txn().Rollback() - it, err := idx.SeekFirst(txn) + it, err := idx.SeekFirst(ctx.Txn()) c.Assert(err, IsNil) defer it.Close() @@ -590,7 +589,8 @@ LOOP: t := s.testGetTable(c, "t2") i := 0 j := 0 - defer ctx.RollbackTxn() + ctx.NewTxn() + defer ctx.Txn().Rollback() err := t.IterRecords(ctx, t.FirstKey(), t.Cols(), func(h int64, data []types.Datum, cols []*table.Column) (bool, error) { i++ @@ -622,8 +622,6 @@ func (s *testDBSuite) testDropColumn(c *C) { } // get c4 column id - ctx := s.s.(context.Context) - go func() { sessionExec(c, s.store, "alter table t2 drop column c4") done <- struct{}{} @@ -666,10 +664,6 @@ LOOP: count, ok := rows[0][0].(int64) c.Assert(ok, IsTrue) c.Assert(count, Greater, int64(0)) - - _, err := ctx.GetTxn(true) - c.Assert(err, IsNil) - ctx.CommitTxn() } func (s *testDBSuite) testChangeColumn(c *C) { diff --git a/ddl/ddl_test.go b/ddl/ddl_test.go index 6482abfe5ea05..4900a6ffb9cd7 100644 --- a/ddl/ddl_test.go +++ b/ddl/ddl_test.go @@ -48,10 +48,9 @@ func testNewContext(c *C, d *ddl) context.Context { } func getSchemaVer(c *C, ctx context.Context) int64 { - txn, err := ctx.GetTxn(true) + err := ctx.NewTxn() c.Assert(err, IsNil) - c.Assert(txn, NotNil) - m := meta.NewMeta(txn) + m := meta.NewMeta(ctx.Txn()) ver, err := m.GetSchemaVersion() c.Assert(err, IsNil) return ver @@ -75,9 +74,8 @@ func checkEqualTable(c *C, t1, t2 *model.TableInfo) { } func checkHistoryJobArgs(c *C, ctx context.Context, id int64, args *historyJobArgs) { - txn, err := ctx.GetTxn(true) - c.Assert(err, IsNil) - t := meta.NewMeta(txn) + c.Assert(ctx.NewTxn(), IsNil) + t := meta.NewMeta(ctx.Txn()) historyJob, err := t.GetHistoryDDLJob(id) c.Assert(err, IsNil) diff --git a/ddl/foreign_key_test.go b/ddl/foreign_key_test.go index 3d3561f205b83..7dc4976d62321 100644 --- a/ddl/foreign_key_test.go +++ b/ddl/foreign_key_test.go @@ -124,12 +124,12 @@ func (s *testForeighKeySuite) TestForeignKey(c *C) { testCreateSchema(c, ctx, d, s.dbInfo) tblInfo := testTableInfo(c, d, "t", 3) - _, err := ctx.GetTxn(true) + err := ctx.NewTxn() c.Assert(err, IsNil) testCreateTable(c, ctx, d, s.dbInfo, tblInfo) - err = ctx.CommitTxn() + err = ctx.Txn().Commit() c.Assert(err, IsNil) // fix data race @@ -155,7 +155,7 @@ func (s *testForeighKeySuite) TestForeignKey(c *C) { job := s.testCreateForeignKey(c, tblInfo, "c1_fk", []string{"c1"}, "t2", []string{"c1"}, ast.ReferOptionCascade, ast.ReferOptionSetNull) testCheckJobDone(c, d, job, true) - err = ctx.CommitTxn() + err = ctx.Txn().Commit() c.Assert(err, IsNil) mu.Lock() c.Assert(checkOK, IsTrue) @@ -186,7 +186,7 @@ func (s *testForeighKeySuite) TestForeignKey(c *C) { c.Assert(checkOK, IsTrue) mu.Unlock() - _, err = ctx.GetTxn(true) + err = ctx.NewTxn() c.Assert(err, IsNil) tc.onJobUpdated = func(job *model.Job) { @@ -198,7 +198,7 @@ func (s *testForeighKeySuite) TestForeignKey(c *C) { job = testDropTable(c, ctx, d, s.dbInfo, tblInfo) testCheckJobDone(c, d, job, false) - err = ctx.CommitTxn() + err = ctx.Txn().Commit() c.Assert(err, IsNil) d.close() diff --git a/ddl/index_change_test.go b/ddl/index_change_test.go index 679e866a2f47e..6930fdd0e0ad0 100644 --- a/ddl/index_change_test.go +++ b/ddl/index_change_test.go @@ -54,7 +54,7 @@ func (s *testIndexChangeSuite) TestIndexChange(c *C) { tblInfo.Columns[0].Flag = mysql.PriKeyFlag | mysql.NotNullFlag tblInfo.PKIsHandle = true ctx := testNewContext(c, d) - _, err := ctx.GetTxn(true) + err := ctx.NewTxn() c.Assert(err, IsNil) testCreateTable(c, ctx, d, s.dbInfo, tblInfo) originTable := testGetTable(c, d, s.dbInfo.ID, tblInfo.ID) @@ -67,7 +67,7 @@ func (s *testIndexChangeSuite) TestIndexChange(c *C) { _, err = originTable.AddRecord(ctx, types.MakeDatums(3, 3)) c.Assert(err, IsNil) - err = ctx.CommitTxn() + err = ctx.Txn().Commit() c.Assert(err, IsNil) tc := &testDDLCallback{} @@ -115,7 +115,7 @@ func (s *testIndexChangeSuite) TestIndexChange(c *C) { d.setHook(tc) testCreateIndex(c, ctx, d, s.dbInfo, originTable.Meta(), false, "c2", "c2") c.Check(errors.ErrorStack(checkErr), Equals, "") - + c.Assert(ctx.Txn().Commit(), IsNil) d.Stop() prevState = model.StatePublic var noneTable table.Table @@ -125,13 +125,14 @@ func (s *testIndexChangeSuite) TestIndexChange(c *C) { } prevState = job.SchemaState var err error + ctx1 := testNewContext(c, d) switch job.SchemaState { case model.StateWriteOnly: writeOnlyTable, err = getCurrentTable(d, s.dbInfo.ID, tblInfo.ID) if err != nil { checkErr = errors.Trace(err) } - err = s.checkDropWriteOnly(d, ctx, publicTable, writeOnlyTable) + err = s.checkDropWriteOnly(d, ctx1, publicTable, writeOnlyTable) if err != nil { checkErr = errors.Trace(err) } @@ -140,7 +141,7 @@ func (s *testIndexChangeSuite) TestIndexChange(c *C) { if err != nil { checkErr = errors.Trace(err) } - err = s.checkDropDeleteOnly(d, ctx, writeOnlyTable, deleteOnlyTable) + err = s.checkDropDeleteOnly(d, ctx1, writeOnlyTable, deleteOnlyTable) if err != nil { checkErr = errors.Trace(err) } @@ -160,12 +161,8 @@ func (s *testIndexChangeSuite) TestIndexChange(c *C) { } func checkIndexExists(ctx context.Context, tbl table.Table, indexValue interface{}, handle int64, exists bool) error { - txn, err := ctx.GetTxn(true) - if err != nil { - return errors.Trace(err) - } idx := tbl.Indices()[0] - doesExist, _, err := idx.Exist(txn, types.MakeDatums(indexValue), handle) + doesExist, _, err := idx.Exist(ctx.Txn(), types.MakeDatums(indexValue), handle) if err != nil { return errors.Trace(err) } @@ -180,7 +177,11 @@ func checkIndexExists(ctx context.Context, tbl table.Table, indexValue interface func (s *testIndexChangeSuite) checkAddWriteOnly(d *ddl, ctx context.Context, delOnlyTbl, writeOnlyTbl table.Table) error { // DeleteOnlyTable: insert t values (4, 4); - _, err := delOnlyTbl.AddRecord(ctx, types.MakeDatums(4, 4)) + err := ctx.NewTxn() + if err != nil { + return errors.Trace(err) + } + _, err = delOnlyTbl.AddRecord(ctx, types.MakeDatums(4, 4)) if err != nil { return errors.Trace(err) } @@ -249,7 +250,11 @@ func (s *testIndexChangeSuite) checkAddWriteOnly(d *ddl, ctx context.Context, de func (s *testIndexChangeSuite) checkAddPublic(d *ddl, ctx context.Context, writeTbl, publicTbl table.Table) error { // WriteOnlyTable: insert t values (6, 6) - _, err := writeTbl.AddRecord(ctx, types.MakeDatums(6, 6)) + err := ctx.NewTxn() + if err != nil { + return errors.Trace(err) + } + _, err = writeTbl.AddRecord(ctx, types.MakeDatums(6, 6)) if err != nil { return errors.Trace(err) } @@ -307,12 +312,16 @@ func (s *testIndexChangeSuite) checkAddPublic(d *ddl, ctx context.Context, write return errors.Trace(err) } } - return nil + return ctx.Txn().Commit() } func (s *testIndexChangeSuite) checkDropWriteOnly(d *ddl, ctx context.Context, publicTbl, writeTbl table.Table) error { // WriteOnlyTable insert t values (8, 8) - _, err := writeTbl.AddRecord(ctx, types.MakeDatums(8, 8)) + err := ctx.NewTxn() + if err != nil { + return errors.Trace(err) + } + _, err = writeTbl.AddRecord(ctx, types.MakeDatums(8, 8)) if err != nil { return errors.Trace(err) } @@ -343,12 +352,16 @@ func (s *testIndexChangeSuite) checkDropWriteOnly(d *ddl, ctx context.Context, p if err != nil { return errors.Trace(err) } - return nil + return ctx.Txn().Commit() } func (s *testIndexChangeSuite) checkDropDeleteOnly(d *ddl, ctx context.Context, writeTbl, delTbl table.Table) error { // WriteOnlyTable insert t values (9, 9) - _, err := writeTbl.AddRecord(ctx, types.MakeDatums(9, 9)) + err := ctx.NewTxn() + if err != nil { + return errors.Trace(err) + } + _, err = writeTbl.AddRecord(ctx, types.MakeDatums(9, 9)) if err != nil { return errors.Trace(err) } @@ -384,5 +397,5 @@ func (s *testIndexChangeSuite) checkDropDeleteOnly(d *ddl, ctx context.Context, if err != nil { return errors.Trace(err) } - return nil + return ctx.Txn().Commit() } diff --git a/ddl/index_test.go b/ddl/index_test.go index b0a9c5f42f5c6..211680e7b17b4 100644 --- a/ddl/index_test.go +++ b/ddl/index_test.go @@ -86,23 +86,19 @@ func (s *testIndexSuite) TestIndex(c *C) { defer testleak.AfterTest(c)() tblInfo := testTableInfo(c, s.d, "t1", 3) ctx := testNewContext(c, s.d) - defer ctx.RollbackTxn() - - txn, err := ctx.GetTxn(true) - c.Assert(err, IsNil) testCreateTable(c, ctx, s.d, s.dbInfo, tblInfo) t := testGetTable(c, s.d, s.dbInfo.ID, tblInfo.ID) - + err := ctx.NewTxn() + c.Assert(err, IsNil) num := 10 for i := 0; i < num; i++ { _, err = t.AddRecord(ctx, types.MakeDatums(i, i, i)) c.Assert(err, IsNil) } - err = ctx.CommitTxn() - c.Assert(err, IsNil) + c.Assert(ctx.NewTxn(), IsNil) i := int64(0) t.IterRecords(ctx, t.FirstKey(), t.Cols(), func(h int64, data []types.Datum, cols []*table.Column) (bool, error) { @@ -118,6 +114,7 @@ func (s *testIndexSuite) TestIndex(c *C) { index := tables.FindIndexByColName(t, "c1") c.Assert(index, NotNil) + c.Assert(ctx.NewTxn(), IsNil) h, err := t.AddRecord(ctx, types.MakeDatums(num+1, 1, 1)) c.Assert(err, IsNil) @@ -128,10 +125,10 @@ func (s *testIndexSuite) TestIndex(c *C) { h, err = t.AddRecord(ctx, types.MakeDatums(1, 1, 1)) c.Assert(err, NotNil) - txn, err = ctx.GetTxn(true) + err = ctx.NewTxn() c.Assert(err, IsNil) - exist, _, err := index.Exist(txn, types.MakeDatums(1), h) + exist, _, err := index.Exist(ctx.Txn(), types.MakeDatums(1), h) c.Assert(err, IsNil) c.Assert(exist, IsTrue) @@ -142,10 +139,10 @@ func (s *testIndexSuite) TestIndex(c *C) { index1 := tables.FindIndexByColName(t, "c1") c.Assert(index1, IsNil) - txn, err = ctx.GetTxn(true) + err = ctx.NewTxn() c.Assert(err, IsNil) - exist, _, err = index.Exist(txn, types.MakeDatums(1), h) + exist, _, err = index.Exist(ctx.Txn(), types.MakeDatums(1), h) c.Assert(err, IsNil) c.Assert(exist, IsFalse) @@ -176,15 +173,12 @@ func (s *testIndexSuite) testGetIndex(c *C, t table.Table, name string, isExist func (s *testIndexSuite) checkIndexKVExist(c *C, ctx context.Context, t table.Table, handle int64, indexCol table.Index, columnValues []types.Datum, isExist bool) { c.Assert(len(indexCol.Meta().Columns), Equals, len(columnValues)) - txn, err := ctx.GetTxn(true) + err := ctx.NewTxn() c.Assert(err, IsNil) - exist, _, err := indexCol.Exist(txn, columnValues, handle) + exist, _, err := indexCol.Exist(ctx.Txn(), columnValues, handle) c.Assert(err, IsNil) c.Assert(exist, Equals, isExist) - - err = ctx.CommitTxn() - c.Assert(err, IsNil) } func (s *testIndexSuite) checkNoneIndex(c *C, ctx context.Context, d *ddl, tblInfo *model.TableInfo, handle int64, index table.Index, row []types.Datum) { @@ -201,12 +195,10 @@ func (s *testIndexSuite) checkNoneIndex(c *C, ctx context.Context, d *ddl, tblIn func (s *testIndexSuite) checkDeleteOnlyIndex(c *C, ctx context.Context, d *ddl, tblInfo *model.TableInfo, handle int64, index table.Index, row []types.Datum, isDropped bool) { t := testGetTable(c, d, s.dbInfo.ID, tblInfo.ID) - - _, err := ctx.GetTxn(true) - c.Assert(err, IsNil) + c.Assert(ctx.NewTxn(), IsNil) i := int64(0) - err = t.IterRecords(ctx, t.FirstKey(), t.Cols(), func(h int64, data []types.Datum, cols []*table.Column) (bool, error) { + err := t.IterRecords(ctx, t.FirstKey(), t.Cols(), func(h int64, data []types.Datum, cols []*table.Column) (bool, error) { c.Assert(data, DeepEquals, row) i++ return true, nil @@ -222,15 +214,13 @@ func (s *testIndexSuite) checkDeleteOnlyIndex(c *C, ctx context.Context, d *ddl, s.checkIndexKVExist(c, ctx, t, handle, index, columnValues, isDropped) // Test add a new row. - _, err = ctx.GetTxn(true) - c.Assert(err, IsNil) + c.Assert(ctx.NewTxn(), IsNil) newRow := types.MakeDatums(int64(11), int64(22), int64(33)) handle, err = t.AddRecord(ctx, newRow) c.Assert(err, IsNil) - _, err = ctx.GetTxn(true) - c.Assert(err, IsNil) + c.Assert(ctx.NewTxn(), IsNil) rows := [][]types.Datum{row, newRow} @@ -249,8 +239,7 @@ func (s *testIndexSuite) checkDeleteOnlyIndex(c *C, ctx context.Context, d *ddl, s.checkIndexKVExist(c, ctx, t, handle, index, columnValues, false) // Test update a new row. - _, err = ctx.GetTxn(true) - c.Assert(err, IsNil) + c.Assert(ctx.NewTxn(), IsNil) newUpdateRow := types.MakeDatums(int64(44), int64(55), int64(66)) touched := map[int]bool{0: true, 1: true, 2: true} @@ -266,14 +255,11 @@ func (s *testIndexSuite) checkDeleteOnlyIndex(c *C, ctx context.Context, d *ddl, s.checkIndexKVExist(c, ctx, t, handle, index, columnValues, false) // Test remove a row. - _, err = ctx.GetTxn(true) - c.Assert(err, IsNil) + c.Assert(ctx.NewTxn(), IsNil) err = t.RemoveRecord(ctx, handle, newUpdateRow) c.Assert(err, IsNil) - - _, err = ctx.GetTxn(true) - c.Assert(err, IsNil) + c.Assert(ctx.NewTxn(), IsNil) i = int64(0) t.IterRecords(ctx, t.FirstKey(), t.Cols(), func(h int64, data []types.Datum, cols []*table.Column) (bool, error) { @@ -289,11 +275,10 @@ func (s *testIndexSuite) checkDeleteOnlyIndex(c *C, ctx context.Context, d *ddl, func (s *testIndexSuite) checkWriteOnlyIndex(c *C, ctx context.Context, d *ddl, tblInfo *model.TableInfo, handle int64, index table.Index, row []types.Datum, isDropped bool) { t := testGetTable(c, d, s.dbInfo.ID, tblInfo.ID) - _, err := ctx.GetTxn(true) - c.Assert(err, IsNil) + c.Assert(ctx.NewTxn(), IsNil) i := int64(0) - err = t.IterRecords(ctx, t.FirstKey(), t.Cols(), func(h int64, data []types.Datum, cols []*table.Column) (bool, error) { + err := t.IterRecords(ctx, t.FirstKey(), t.Cols(), func(h int64, data []types.Datum, cols []*table.Column) (bool, error) { c.Assert(data, DeepEquals, row) i++ return true, nil @@ -309,15 +294,13 @@ func (s *testIndexSuite) checkWriteOnlyIndex(c *C, ctx context.Context, d *ddl, s.checkIndexKVExist(c, ctx, t, handle, index, columnValues, isDropped) // Test add a new row. - _, err = ctx.GetTxn(true) - c.Assert(err, IsNil) + c.Assert(ctx.NewTxn(), IsNil) newRow := types.MakeDatums(int64(11), int64(22), int64(33)) handle, err = t.AddRecord(ctx, newRow) c.Assert(err, IsNil) - _, err = ctx.GetTxn(true) - c.Assert(err, IsNil) + c.Assert(ctx.NewTxn(), IsNil) rows := [][]types.Datum{row, newRow} @@ -336,7 +319,7 @@ func (s *testIndexSuite) checkWriteOnlyIndex(c *C, ctx context.Context, d *ddl, s.checkIndexKVExist(c, ctx, t, handle, index, columnValues, true) // Test update a new row. - _, err = ctx.GetTxn(true) + err = ctx.NewTxn() c.Assert(err, IsNil) newUpdateRow := types.MakeDatums(int64(44), int64(55), int64(66)) @@ -353,13 +336,13 @@ func (s *testIndexSuite) checkWriteOnlyIndex(c *C, ctx context.Context, d *ddl, s.checkIndexKVExist(c, ctx, t, handle, index, columnValues, true) // Test remove a row. - _, err = ctx.GetTxn(true) + err = ctx.NewTxn() c.Assert(err, IsNil) err = t.RemoveRecord(ctx, handle, newUpdateRow) c.Assert(err, IsNil) - _, err = ctx.GetTxn(true) + err = ctx.NewTxn() c.Assert(err, IsNil) i = int64(0) @@ -376,11 +359,10 @@ func (s *testIndexSuite) checkWriteOnlyIndex(c *C, ctx context.Context, d *ddl, func (s *testIndexSuite) checkReorganizationIndex(c *C, ctx context.Context, d *ddl, tblInfo *model.TableInfo, handle int64, index table.Index, row []types.Datum, isDropped bool) { t := testGetTable(c, d, s.dbInfo.ID, tblInfo.ID) - _, err := ctx.GetTxn(true) - c.Assert(err, IsNil) + c.Assert(ctx.NewTxn(), IsNil) i := int64(0) - err = t.IterRecords(ctx, t.FirstKey(), t.Cols(), func(h int64, data []types.Datum, cols []*table.Column) (bool, error) { + err := t.IterRecords(ctx, t.FirstKey(), t.Cols(), func(h int64, data []types.Datum, cols []*table.Column) (bool, error) { c.Assert(data, DeepEquals, row) i++ return true, nil @@ -389,15 +371,13 @@ func (s *testIndexSuite) checkReorganizationIndex(c *C, ctx context.Context, d * c.Assert(i, Equals, int64(1)) // Test add a new row. - _, err = ctx.GetTxn(true) - c.Assert(err, IsNil) + c.Assert(ctx.NewTxn(), IsNil) newRow := types.MakeDatums(int64(11), int64(22), int64(33)) handle, err = t.AddRecord(ctx, newRow) c.Assert(err, IsNil) - _, err = ctx.GetTxn(true) - c.Assert(err, IsNil) + c.Assert(ctx.NewTxn(), IsNil) rows := [][]types.Datum{row, newRow} @@ -417,8 +397,7 @@ func (s *testIndexSuite) checkReorganizationIndex(c *C, ctx context.Context, d * s.checkIndexKVExist(c, ctx, t, handle, index, columnValues, !isDropped) // Test update a new row. - _, err = ctx.GetTxn(true) - c.Assert(err, IsNil) + c.Assert(ctx.NewTxn(), IsNil) newUpdateRow := types.MakeDatums(int64(44), int64(55), int64(66)) touched := map[int]bool{0: true, 1: true, 2: true} @@ -434,14 +413,10 @@ func (s *testIndexSuite) checkReorganizationIndex(c *C, ctx context.Context, d * s.checkIndexKVExist(c, ctx, t, handle, index, columnValues, !isDropped) // Test remove a row. - _, err = ctx.GetTxn(true) - c.Assert(err, IsNil) + c.Assert(ctx.NewTxn(), IsNil) + c.Assert(t.RemoveRecord(ctx, handle, newUpdateRow), IsNil) - err = t.RemoveRecord(ctx, handle, newUpdateRow) - c.Assert(err, IsNil) - - _, err = ctx.GetTxn(true) - c.Assert(err, IsNil) + c.Assert(ctx.NewTxn(), IsNil) i = int64(0) t.IterRecords(ctx, t.FirstKey(), t.Cols(), func(h int64, data []types.Datum, cols []*table.Column) (bool, error) { @@ -456,11 +431,10 @@ func (s *testIndexSuite) checkReorganizationIndex(c *C, ctx context.Context, d * func (s *testIndexSuite) checkPublicIndex(c *C, ctx context.Context, d *ddl, tblInfo *model.TableInfo, handle int64, index table.Index, row []types.Datum) { t := testGetTable(c, d, s.dbInfo.ID, tblInfo.ID) - _, err := ctx.GetTxn(true) - c.Assert(err, IsNil) + c.Assert(ctx.NewTxn(), IsNil) i := int64(0) - err = t.IterRecords(ctx, t.FirstKey(), t.Cols(), func(h int64, data []types.Datum, cols []*table.Column) (bool, error) { + err := t.IterRecords(ctx, t.FirstKey(), t.Cols(), func(h int64, data []types.Datum, cols []*table.Column) (bool, error) { c.Assert(data, DeepEquals, row) i++ return true, nil @@ -476,15 +450,14 @@ func (s *testIndexSuite) checkPublicIndex(c *C, ctx context.Context, d *ddl, tbl s.checkIndexKVExist(c, ctx, t, handle, index, columnValues, true) // Test add a new row. - _, err = ctx.GetTxn(true) + err = ctx.NewTxn() c.Assert(err, IsNil) newRow := types.MakeDatums(int64(11), int64(22), int64(33)) handle, err = t.AddRecord(ctx, newRow) c.Assert(err, IsNil) - _, err = ctx.GetTxn(true) - c.Assert(err, IsNil) + c.Assert(ctx.NewTxn(), IsNil) rows := [][]types.Datum{row, newRow} @@ -503,13 +476,11 @@ func (s *testIndexSuite) checkPublicIndex(c *C, ctx context.Context, d *ddl, tbl s.checkIndexKVExist(c, ctx, t, handle, index, columnValues, true) // Test update a new row. - _, err = ctx.GetTxn(true) - c.Assert(err, IsNil) newUpdateRow := types.MakeDatums(int64(44), int64(55), int64(66)) touched := map[int]bool{0: true, 1: true, 2: true} - err = t.UpdateRecord(ctx, handle, newRow, newUpdateRow, touched) - c.Assert(err, IsNil) + c.Assert(ctx.NewTxn(), IsNil) + c.Assert(t.UpdateRecord(ctx, handle, newRow, newUpdateRow, touched), IsNil) s.checkIndexKVExist(c, ctx, t, handle, index, columnValues, false) @@ -520,22 +491,16 @@ func (s *testIndexSuite) checkPublicIndex(c *C, ctx context.Context, d *ddl, tbl s.checkIndexKVExist(c, ctx, t, handle, index, columnValues, true) // Test remove a row. - _, err = ctx.GetTxn(true) - c.Assert(err, IsNil) - - err = t.RemoveRecord(ctx, handle, newUpdateRow) - c.Assert(err, IsNil) - - _, err = ctx.GetTxn(true) - c.Assert(err, IsNil) + c.Assert(ctx.NewTxn(), IsNil) + c.Assert(t.RemoveRecord(ctx, handle, newUpdateRow), IsNil) + c.Assert(ctx.NewTxn(), IsNil) i = int64(0) t.IterRecords(ctx, t.FirstKey(), t.Cols(), func(h int64, data []types.Datum, cols []*table.Column) (bool, error) { i++ return true, nil }) c.Assert(i, Equals, int64(1)) - s.checkIndexKVExist(c, ctx, t, handle, index, columnValues, false) s.testGetIndex(c, t, index.Meta().Columns[0].Name.L, true) } @@ -562,19 +527,16 @@ func (s *testIndexSuite) TestAddIndex(c *C) { d := newDDL(s.store, nil, nil, testLease) tblInfo := testTableInfo(c, d, "t", 3) ctx := testNewContext(c, d) - - _, err := ctx.GetTxn(true) - c.Assert(err, IsNil) - testCreateTable(c, ctx, d, s.dbInfo, tblInfo) t := testGetTable(c, d, s.dbInfo.ID, tblInfo.ID) + c.Assert(ctx.NewTxn(), IsNil) row := types.MakeDatums(int64(1), int64(2), int64(3)) handle, err := t.AddRecord(ctx, row) c.Assert(err, IsNil) - err = ctx.CommitTxn() + err = ctx.NewTxn() c.Assert(err, IsNil) checkOK := false @@ -612,13 +574,13 @@ func (s *testIndexSuite) TestAddIndex(c *C) { job = testCreateIndex(c, ctx, d, s.dbInfo, tblInfo, true, "c1", "c1") testCheckJobDone(c, d, job, true) - _, err = ctx.GetTxn(true) + err = ctx.NewTxn() c.Assert(err, IsNil) job = testDropTable(c, ctx, d, s.dbInfo, tblInfo) testCheckJobDone(c, d, job, false) - err = ctx.CommitTxn() + err = ctx.Txn().Commit() c.Assert(err, IsNil) d.close() @@ -631,7 +593,7 @@ func (s *testIndexSuite) TestDropIndex(c *C) { tblInfo := testTableInfo(c, d, "t", 3) ctx := testNewContext(c, d) - _, err := ctx.GetTxn(true) + err := ctx.NewTxn() c.Assert(err, IsNil) testCreateTable(c, ctx, d, s.dbInfo, tblInfo) @@ -642,15 +604,12 @@ func (s *testIndexSuite) TestDropIndex(c *C) { handle, err := t.AddRecord(ctx, row) c.Assert(err, IsNil) - err = ctx.CommitTxn() + err = ctx.NewTxn() c.Assert(err, IsNil) job := testCreateIndex(c, ctx, s.d, s.dbInfo, tblInfo, true, "c1_uni", "c1") testCheckJobDone(c, d, job, true) - err = ctx.CommitTxn() - c.Assert(err, IsNil) - checkOK := false oldIndexCol := tables.NewIndex(tblInfo, &model.IndexInfo{}) @@ -685,13 +644,13 @@ func (s *testIndexSuite) TestDropIndex(c *C) { job = testDropIndex(c, ctx, d, s.dbInfo, tblInfo, "c1_uni") testCheckJobDone(c, d, job, false) - _, err = ctx.GetTxn(true) + err = ctx.NewTxn() c.Assert(err, IsNil) job = testDropTable(c, ctx, d, s.dbInfo, tblInfo) testCheckJobDone(c, d, job, false) - err = ctx.CommitTxn() + err = ctx.Txn().Commit() c.Assert(err, IsNil) d.close() @@ -706,7 +665,7 @@ func (s *testIndexSuite) TestAddIndexWithNullColumn(c *C) { tblInfo.Columns[1].DefaultValue = nil ctx := testNewContext(c, d) - _, err := ctx.GetTxn(true) + err := ctx.NewTxn() c.Assert(err, IsNil) testCreateTable(c, ctx, d, s.dbInfo, tblInfo) @@ -718,7 +677,7 @@ func (s *testIndexSuite) TestAddIndexWithNullColumn(c *C) { handle, err := t.AddRecord(ctx, row) c.Assert(err, IsNil) - err = ctx.CommitTxn() + err = ctx.NewTxn() c.Assert(err, IsNil) checkOK := false @@ -753,13 +712,12 @@ func (s *testIndexSuite) TestAddIndexWithNullColumn(c *C) { job := testCreateIndex(c, ctx, d, s.dbInfo, tblInfo, true, "c2", "c2") testCheckJobDone(c, d, job, true) - _, err = ctx.GetTxn(true) - c.Assert(err, IsNil) + c.Assert(ctx.NewTxn(), IsNil) job = testDropTable(c, ctx, d, s.dbInfo, tblInfo) testCheckJobDone(c, d, job, false) - err = ctx.CommitTxn() + err = ctx.Txn().Commit() c.Assert(err, IsNil) d.close() diff --git a/ddl/reorg_test.go b/ddl/reorg_test.go index 7f2a42c072627..fa527a74703b0 100644 --- a/ddl/reorg_test.go +++ b/ddl/reorg_test.go @@ -48,16 +48,16 @@ func (s *testDDLSuite) TestReorg(c *C) { c.Assert(ctx.Value(testCtxKey), Equals, 1) ctx.ClearValue(testCtxKey) - txn, err := ctx.GetTxn(true) + err := ctx.NewTxn() c.Assert(err, IsNil) - txn.Set([]byte("a"), []byte("b")) - err = ctx.RollbackTxn() + ctx.Txn().Set([]byte("a"), []byte("b")) + err = ctx.Txn().Rollback() c.Assert(err, IsNil) - txn, err = ctx.GetTxn(false) + err = ctx.NewTxn() c.Assert(err, IsNil) - txn.Set([]byte("a"), []byte("b")) - err = ctx.CommitTxn() + ctx.Txn().Set([]byte("a"), []byte("b")) + err = ctx.Txn().Commit() c.Assert(err, IsNil) done := make(chan struct{}) @@ -141,7 +141,7 @@ func (s *testDDLSuite) TestReorgOwner(c *C) { c.Assert(err, IsNil) } - err := ctx.CommitTxn() + err := ctx.Txn().Commit() c.Assert(err, IsNil) tc := &testDDLCallback{} diff --git a/ddl/schema_test.go b/ddl/schema_test.go index ad96b73a4cb8b..9400485fe862f 100644 --- a/ddl/schema_test.go +++ b/ddl/schema_test.go @@ -222,6 +222,7 @@ func (s *testSchemaSuite) TestSchemaWaitJob(c *C) { func testRunInterruptedJob(c *C, d *ddl, job *model.Job) { ctx := mock.NewContext() + ctx.Store = d.store done := make(chan error, 1) go func() { done <- d.doDDLJob(ctx, job) diff --git a/ddl/stat_test.go b/ddl/stat_test.go index aee046aff5515..100482c896933 100644 --- a/ddl/stat_test.go +++ b/ddl/stat_test.go @@ -58,6 +58,7 @@ func (s *testStatSuite) TestStat(c *C) { } ctx := mock.NewContext() + ctx.Store = store done := make(chan error, 1) go func() { done <- d.doDDLJob(ctx, job) diff --git a/ddl/table_test.go b/ddl/table_test.go index 83e6e2b4f8385..16064b18d5cc7 100644 --- a/ddl/table_test.go +++ b/ddl/table_test.go @@ -173,7 +173,6 @@ func (s *testTableSuite) TestTable(c *C) { d := s.d ctx := testNewContext(c, d) - defer ctx.RollbackTxn() tblInfo := testTableInfo(c, d, "t", 3) job := testCreateTable(c, ctx, d, s.dbInfo, tblInfo) diff --git a/domain/domain_test.go b/domain/domain_test.go index f3c4a43c3c28c..8a8d87f7e75a6 100644 --- a/domain/domain_test.go +++ b/domain/domain_test.go @@ -42,11 +42,11 @@ func (*testSuite) TestT(c *C) { store, err := driver.Open("memory") c.Assert(err, IsNil) defer testleak.AfterTest(c)() - ctx := mock.NewContext() - dom, err := NewDomain(store, 80*time.Millisecond) c.Assert(err, IsNil) store = dom.Store() + ctx := mock.NewContext() + ctx.Store = store dd := dom.DDL() c.Assert(dd, NotNil) c.Assert(dd.GetLease(), Equals, 80*time.Millisecond) diff --git a/executor/adapter.go b/executor/adapter.go index 6a0ea0bd4def8..b4f30f7ea5135 100644 --- a/executor/adapter.go +++ b/executor/adapter.go @@ -28,6 +28,7 @@ type recordSet struct { fields []*ast.ResultField executor Executor schema expression.Schema + ctx context.Context } func (a *recordSet) Fields() ([]*ast.ResultField, error) { @@ -125,9 +126,9 @@ func (a *statement) Exec(ctx context.Context) (ast.RecordSet, error) { } } } - return &recordSet{ executor: e, schema: e.Schema(), + ctx: ctx, }, nil } diff --git a/executor/builder.go b/executor/builder.go index e0caa0765fc40..7ccb0d5f204b0 100644 --- a/executor/builder.go +++ b/executor/builder.go @@ -131,17 +131,12 @@ func (b *executorBuilder) buildShowDDL(v *plan.ShowDDL) Executor { ctx: b.ctx, schema: v.GetSchema(), } - txn, err := e.ctx.GetTxn(false) + ddlInfo, err := inspectkv.GetDDLInfo(e.ctx.Txn()) if err != nil { b.err = errors.Trace(err) return nil } - ddlInfo, err := inspectkv.GetDDLInfo(txn) - if err != nil { - b.err = errors.Trace(err) - return nil - } - bgInfo, err := inspectkv.GetBgDDLInfo(txn) + bgInfo, err := inspectkv.GetBgDDLInfo(e.ctx.Txn()) if err != nil { b.err = errors.Trace(err) return nil @@ -292,6 +287,7 @@ func (b *executorBuilder) buildLoadData(v *plan.LoadData) Executor { Table: tbl, FieldsInfo: v.FieldsInfo, LinesInfo: v.LinesInfo, + Ctx: b.ctx, }, } } @@ -489,12 +485,7 @@ func (b *executorBuilder) buildTableDual(v *plan.TableDual) Executor { func (b *executorBuilder) getStartTS() uint64 { startTS := b.ctx.GetSessionVars().SnapshotTS if startTS == 0 { - txn, err := b.ctx.GetTxn(false) - if err != nil { - b.err = errors.Trace(err) - return 0 - } - startTS = txn.StartTS() + startTS = b.ctx.Txn().StartTS() } return startTS } diff --git a/executor/executor.go b/executor/executor.go index 899621d73fb71..4831d89e9794d 100644 --- a/executor/executor.go +++ b/executor/executor.go @@ -133,7 +133,6 @@ func (e *ShowDDLExec) Next() (*Row, error) { if e.done { return nil, nil } - var ddlOwner, ddlJob string if e.ddlInfo.Owner != nil { ddlOwner = e.ddlInfo.Owner.String() @@ -198,10 +197,7 @@ func (e *CheckTableExec) Next() (*Row, error) { return nil, errors.Trace(err) } for _, idx := range tb.Indices() { - txn, err := e.ctx.GetTxn(false) - if err != nil { - return nil, errors.Trace(err) - } + txn := e.ctx.Txn() err = inspectkv.CompareIndexData(txn, tb, idx) if err != nil { return nil, errors.Errorf("%v err:%v", t.Name, err) @@ -247,10 +243,7 @@ func (e *SelectLockExec) Next() (*Row, error) { } if len(row.RowKeys) != 0 && e.Lock == ast.SelectLockForUpdate { e.ctx.GetSessionVars().TxnCtx.ForUpdate = true - txn, err := e.ctx.GetTxn(false) - if err != nil { - return nil, errors.Trace(err) - } + txn := e.ctx.Txn() for _, k := range row.RowKeys { lockKey := tablecodec.EncodeRowKeyWithHandle(k.Tbl.Meta().ID, k.Handle) err = txn.LockKeys(lockKey) diff --git a/executor/executor_ddl.go b/executor/executor_ddl.go index 804d3b6fb59db..dae7562858cbb 100644 --- a/executor/executor_ddl.go +++ b/executor/executor_ddl.go @@ -72,6 +72,11 @@ func (e *DDLExec) Next() (*Row, error) { if err != nil { return nil, errors.Trace(err) } + // Update InfoSchema in TxnCtx, so it will pass schema check. + is := sessionctx.GetDomain(e.ctx).InfoSchema() + txnCtx := e.ctx.GetSessionVars().TxnCtx + txnCtx.InfoSchema = is + txnCtx.SchemaVersion = is.SchemaMetaVersion() // DDL will force commit old transaction, after DDL, in transaction status should be false. e.ctx.GetSessionVars().SetStatusFlag(mysql.ServerStatusInTrans, false) e.done = true diff --git a/executor/executor_set_test.go b/executor/executor_set_test.go index 5e77f8e0f27b6..85c155b154165 100644 --- a/executor/executor_set_test.go +++ b/executor/executor_set_test.go @@ -101,6 +101,7 @@ func (s *testSuite) TestSetVar(c *C) { vars := tk.Se.(context.Context).GetSessionVars() tk.Se.CommitTxn() tk.MustExec("set @@autocommit = 1") + c.Assert(vars.InTxn(), IsFalse) c.Assert(vars.IsAutocommit(), IsTrue) tk.MustExec("set @@autocommit = 0") c.Assert(vars.IsAutocommit(), IsFalse) diff --git a/executor/executor_simple.go b/executor/executor_simple.go index 55463a666e4df..8b7b26792552b 100644 --- a/executor/executor_simple.go +++ b/executor/executor_simple.go @@ -112,7 +112,7 @@ func (e *SimpleExec) executeUse(s *ast.UseStmt) error { } func (e *SimpleExec) executeBegin(s *ast.BeginStmt) error { - _, err := e.ctx.GetTxn(true) + err := e.ctx.NewTxn() if err != nil { return errors.Trace(err) } @@ -130,9 +130,11 @@ func (e *SimpleExec) executeCommit(s *ast.CommitStmt) { func (e *SimpleExec) executeRollback(s *ast.RollbackStmt) error { sessVars := e.ctx.GetSessionVars() log.Infof("[%d] execute rollback statement", sessVars.ConnectionID) - err := e.ctx.RollbackTxn() sessVars.SetStatusFlag(mysql.ServerStatusInTrans, false) - return errors.Trace(err) + if e.ctx.Txn().Valid() { + return e.ctx.Txn().Rollback() + } + return nil } func (e *SimpleExec) executeCreateUser(s *ast.CreateUserStmt) error { @@ -213,12 +215,12 @@ func (e *SimpleExec) executeAlterUser(s *ast.AlterUserStmt) error { failedUsers = append(failedUsers, spec.User) } } - - err := e.ctx.CommitTxn() - if err != nil { - return errors.Trace(err) - } if len(failedUsers) > 0 { + // Commit the transaction even if we returns error + err := e.ctx.Txn().Commit() + if err != nil { + return errors.Trace(err) + } errMsg := "Operation ALTER USER failed for " + strings.Join(failedUsers, ",") return terror.ClassExecutor.New(CodeCannotUser, errMsg) } @@ -245,11 +247,12 @@ func (e *SimpleExec) executeDropUser(s *ast.DropUserStmt) error { failedUsers = append(failedUsers, user) } } - err := e.ctx.CommitTxn() - if err != nil { - return errors.Trace(err) - } if len(failedUsers) > 0 { + // Commit the transaction even if we returns error + err := e.ctx.Txn().Commit() + if err != nil { + return errors.Trace(err) + } errMsg := "Operation DROP USER failed for " + strings.Join(failedUsers, ",") return terror.ClassExecutor.New(CodeCannotUser, errMsg) } @@ -371,10 +374,7 @@ func (e *SimpleExec) collectSamples(result ast.RecordSet) (count int64, samples } func (e *SimpleExec) buildStatisticsAndSaveToKV(tn *ast.TableName, count int64, sampleRows []*ast.Row) error { - txn, err := e.ctx.GetTxn(false) - if err != nil { - return errors.Trace(err) - } + txn := e.ctx.Txn() columnSamples := rowsToColumnSamples(sampleRows) sc := e.ctx.GetSessionVars().StmtCtx t, err := statistics.NewTable(sc, tn.TableInfo, int64(txn.StartTS()), count, defaultBucketCount, columnSamples) diff --git a/executor/executor_simple_test.go b/executor/executor_simple_test.go index 4143de39261e1..b004a93e5e5f3 100644 --- a/executor/executor_simple_test.go +++ b/executor/executor_simple_test.go @@ -70,6 +70,22 @@ func (s *testSuite) TestTransaction(c *C) { c.Assert(inTxn(ctx), IsTrue) tk.MustExec("rollback") c.Assert(inTxn(ctx), IsFalse) + + // Test that begin implicitly commits previous transaction. + tk.MustExec("use test") + tk.MustExec("create table txn (a int)") + tk.MustExec("begin") + tk.MustExec("insert txn values (1)") + tk.MustExec("begin") + tk.MustExec("rollback") + tk.MustQuery("select * from txn").Check(testkit.Rows("1")) + + // Test that DDL implicitly commits previous transaction. + tk.MustExec("begin") + tk.MustExec("insert txn values (2)") + tk.MustExec("create table txn2 (a int)") + tk.MustExec("rollback") + tk.MustQuery("select * from txn").Check(testkit.Rows("1", "2")) } func inTxn(ctx context.Context) bool { @@ -214,9 +230,9 @@ func (s *testSuite) TestAnalyzeTable(c *C) { c.Check(err, IsNil) tableID := t.Meta().ID - txn, err := ctx.GetTxn(true) + err = ctx.NewTxn() c.Check(err, IsNil) - meta := meta.NewMeta(txn) + meta := meta.NewMeta(ctx.Txn()) tpb, err := meta.GetTableStats(tableID) c.Check(err, IsNil) c.Check(tpb, NotNil) diff --git a/executor/executor_test.go b/executor/executor_test.go index ea17e4fb88373..22d7887511af1 100644 --- a/executor/executor_test.go +++ b/executor/executor_test.go @@ -167,6 +167,7 @@ type testCase struct { func checkCases(cases []testCase, ld *executor.LoadDataInfo, c *C, tk *testkit.TestKit, ctx context.Context, selectSQL, deleteSQL string) { for _, ca := range cases { + c.Assert(ctx.NewTxn(), IsNil) data, err1 := ld.InsertData(ca.data1, ca.data2) c.Assert(err1, IsNil) if ca.restData == nil { @@ -176,7 +177,7 @@ func checkCases(cases []testCase, ld *executor.LoadDataInfo, c.Assert(data, DeepEquals, ca.restData, Commentf("data1:%v, data2:%v, data:%v", string(ca.data1), string(ca.data2), string(data))) } - err1 = ctx.CommitTxn() + err1 = ctx.Txn().Commit() c.Assert(err1, IsNil) r := tk.MustQuery(selectSQL) r.Check(testkit.Rows(ca.expected...)) @@ -1400,7 +1401,7 @@ func (s *testSuite) TestAdapterStatement(c *C) { c.Check(err, IsNil) compiler := &executor.Compiler{} ctx := se.(context.Context) - tidb.PrepareTxnCtx(ctx) + c.Check(tidb.PrepareTxnCtx(ctx), IsNil) stmtNode, err := s.ParseOneStmt("select 1", "", "") c.Check(err, IsNil) diff --git a/executor/executor_write.go b/executor/executor_write.go index 66eaf4167b14f..4a2843d9ca8db 100644 --- a/executor/executor_write.go +++ b/executor/executor_write.go @@ -258,6 +258,7 @@ func NewLoadDataInfo(row []types.Datum, ctx context.Context, tbl table.Table) *L row: row, insertVal: &InsertValues{ctx: ctx, Table: tbl}, Table: tbl, + Ctx: ctx, } } @@ -270,6 +271,7 @@ type LoadDataInfo struct { Table table.Table FieldsInfo *ast.FieldsClause LinesInfo *ast.LinesClause + Ctx context.Context } // getValidData returns prevData and curData that starts from starting symbol. @@ -576,7 +578,7 @@ func (e *InsertExec) Next() (*Row, error) { if err != nil { return nil, errors.Trace(err) } - txn, err := e.ctx.GetTxn(false) + txn := e.ctx.Txn() if err != nil { return nil, errors.Trace(err) } diff --git a/inspectkv/inspectkv_test.go b/inspectkv/inspectkv_test.go index 174aa59c440f4..d867719d2800f 100644 --- a/inspectkv/inspectkv_test.go +++ b/inspectkv/inspectkv_test.go @@ -195,9 +195,10 @@ func (s *testSuite) TestScan(c *C) { tb, err := tables.TableFromMeta(alloc, s.tbInfo) c.Assert(err, IsNil) indices := tb.Indices() + c.Assert(s.ctx.NewTxn(), IsNil) _, err = tb.AddRecord(s.ctx, types.MakeDatums(1, 10, 11)) c.Assert(err, IsNil) - s.ctx.CommitTxn() + c.Assert(s.ctx.Txn().Commit(), IsNil) record1 := &RecordData{Handle: int64(1), Values: types.MakeDatums(int64(1), int64(10), int64(11))} record2 := &RecordData{Handle: int64(2), Values: types.MakeDatums(int64(2), int64(20), int64(21))} @@ -207,9 +208,10 @@ func (s *testSuite) TestScan(c *C) { c.Assert(err, IsNil) c.Assert(records, DeepEquals, []*RecordData{record1}) + c.Assert(s.ctx.NewTxn(), IsNil) _, err = tb.AddRecord(s.ctx, record2.Values) c.Assert(err, IsNil) - s.ctx.CommitTxn() + c.Assert(s.ctx.Txn().Commit(), IsNil) txn, err := s.store.Begin() c.Assert(err, IsNil) @@ -246,10 +248,12 @@ func (s *testSuite) TestScan(c *C) { s.testIndex(c, tb, tb.Indices()[0]) + c.Assert(s.ctx.NewTxn(), IsNil) err = tb.RemoveRecord(s.ctx, 1, record1.Values) c.Assert(err, IsNil) err = tb.RemoveRecord(s.ctx, 2, record2.Values) c.Assert(err, IsNil) + c.Assert(s.ctx.Txn().Commit(), IsNil) } func newDiffRetError(prefix string, ra, rb *RecordData) string { diff --git a/plan/physical_plan_builder.go b/plan/physical_plan_builder.go index 2331026ff6523..c5a20db8c1fc6 100644 --- a/plan/physical_plan_builder.go +++ b/plan/physical_plan_builder.go @@ -141,12 +141,8 @@ func (p *DataSource) convert2TableScan(prop *requiredProperty) (*physicalPlanInf ts.allocator = p.allocator ts.SetSchema(p.GetSchema()) ts.initIDAndContext(p.ctx) - txn, err := p.ctx.GetTxn(false) - if err != nil { - return nil, errors.Trace(err) - } - if txn != nil { - ts.readOnly = txn.IsReadOnly() + if p.ctx.Txn() != nil { + ts.readOnly = p.ctx.Txn().IsReadOnly() } else { ts.readOnly = true } @@ -196,6 +192,7 @@ func (p *DataSource) convert2TableScan(prop *requiredProperty) (*physicalPlanInf break } } + var err error rowCount, err = getRowCountByTableRange(sc, statsTbl, ts.Ranges, offset) if err != nil { return nil, errors.Trace(err) @@ -222,12 +219,8 @@ func (p *DataSource) convert2IndexScan(prop *requiredProperty, index *model.Inde is.allocator = p.allocator is.initIDAndContext(p.ctx) is.SetSchema(p.schema) - txn, err := p.ctx.GetTxn(false) - if err != nil { - return nil, errors.Trace(err) - } - if txn != nil { - is.readOnly = txn.IsReadOnly() + if p.ctx.Txn() != nil { + is.readOnly = p.ctx.Txn().IsReadOnly() } else { is.readOnly = true } diff --git a/server/conn.go b/server/conn.go index 66bcfb3d724b8..c372b547709cd 100644 --- a/server/conn.go +++ b/server/conn.go @@ -544,7 +544,7 @@ func (cc *clientConn) handleLoadData(loadDataInfo *executor.LoadDataInfo) error if err != nil { return errors.Trace(err) } - + loadDataInfo.Ctx.NewTxn() var prevData []byte var curData []byte var shouldBreak bool @@ -570,8 +570,7 @@ func (cc *clientConn) handleLoadData(loadDataInfo *executor.LoadDataInfo) error break } } - - return nil + return loadDataInfo.Ctx.Txn().Commit() } const queryLogMaxLen = 2048 diff --git a/session.go b/session.go index e590fef896d06..f8b426c523362 100644 --- a/session.go +++ b/session.go @@ -51,13 +51,12 @@ import ( // Session context type Session interface { - Status() uint16 // Flag of current status, such as autocommit. - LastInsertID() uint64 // Last inserted auto_increment id. - AffectedRows() uint64 // Affected rows by latest executed stmt. - SetValue(key fmt.Stringer, value interface{}) // SetValue saves a value associated with this session for key. - Value(key fmt.Stringer) interface{} // Value returns the value associated with this session for key. - Execute(sql string) ([]ast.RecordSet, error) // Execute a sql statement. - String() string // For debug + context.Context + Status() uint16 // Flag of current status, such as autocommit. + LastInsertID() uint64 // Last inserted auto_increment id. + AffectedRows() uint64 // Affected rows by latest executed stmt. + Execute(sql string) ([]ast.RecordSet, error) // Execute a sql statement. + String() string // For debug CommitTxn() error RollbackTxn() error // For execute prepare statement in binary protocol. @@ -112,12 +111,10 @@ func (h *stmtHistory) clone() *stmtHistory { const unlimitedRetryCnt = -1 type session struct { - txn kv.Transaction // current transaction - // It is the schema version in current transaction. If it's 0, the transaction is nil. - schemaVerInCurrTxn int64 - values map[fmt.Stringer]interface{} - store kv.Storage - maxRetryCnt int // Max retry times. If maxRetryCnt <=0, there is no limitation for retry times. + txn kv.Transaction // current transaction + values map[fmt.Stringer]interface{} + store kv.Storage + maxRetryCnt int // Max retry times. If maxRetryCnt <=0, there is no limitation for retry times. // For performance_schema only. stmtState *perfschema.StatementState @@ -155,22 +152,17 @@ func (s *session) checkSchemaValid() error { var ts uint64 if s.txn != nil { ts = s.txn.StartTS() - } else { - s.schemaVerInCurrTxn = 0 } - + txnSchemaVer := s.sessionVars.TxnCtx.SchemaVersion var err error var currSchemaVer int64 for i := 0; i < schemaExpiredRetryTimes; i++ { - currSchemaVer, err = sessionctx.GetDomain(s).SchemaValidity.Check(ts, s.schemaVerInCurrTxn) + currSchemaVer, err = sessionctx.GetDomain(s).SchemaValidity.Check(ts, txnSchemaVer) if err == nil { - if s.txn == nil { - s.schemaVerInCurrTxn = currSchemaVer - } return nil } log.Infof("schema version original %d, current %d, sleep time %v", - s.schemaVerInCurrTxn, currSchemaVer, checkSchemaValiditySleepTime) + txnSchemaVer, currSchemaVer, checkSchemaValiditySleepTime) if terror.ErrorEqual(err, domain.ErrInfoSchemaChanged) { break } @@ -200,7 +192,7 @@ func (s *session) SetConnectionID(connectionID uint64) { } func (s *session) doCommit() error { - if s.txn == nil { + if s.txn == nil || !s.txn.Valid() { return nil } defer func() { @@ -254,11 +246,11 @@ func (s *session) CommitTxn() error { } func (s *session) RollbackTxn() error { - if s.txn == nil { - return nil + var err error + if s.txn != nil && s.txn.Valid() { + err = s.txn.Rollback() } s.cleanRetryInfo() - err := s.txn.Rollback() s.txn = nil s.sessionVars.SetStatusFlag(mysql.ServerStatusInTrans, false) return errors.Trace(err) @@ -313,32 +305,32 @@ func (s *session) Retry() error { var err error retryCnt := 0 for { - PrepareTxnCtx(s) - success := true - s.sessionVars.RetryInfo.ResetOffset() - for _, sr := range nh.history { - st := sr.st - txt := st.OriginText() - if len(txt) > sqlLogMaxLen { - txt = txt[:sqlLogMaxLen] - } - log.Warnf("Retry %s (len:%d)", txt, len(st.OriginText())) - _, err = st.Exec(s) - if err != nil { - if s.isRetryableError(err) { - success = false + err = PrepareTxnCtx(s) + if err == nil { + s.sessionVars.RetryInfo.ResetOffset() + for _, sr := range nh.history { + st := sr.st + txt := st.OriginText() + if len(txt) > sqlLogMaxLen { + txt = txt[:sqlLogMaxLen] + } + log.Warnf("Retry %s (len:%d)", txt, len(st.OriginText())) + _, err = st.Exec(s) + if err != nil { break } - log.Warnf("session:%v, err:%v", s, err) - return errors.Trace(err) } } - if success { + if err == nil { err = s.doCommit() - if !s.isRetryableError(err) { + if err == nil { break } } + if !s.isRetryableError(err) { + log.Warnf("session:%v, err:%v", s, err) + return errors.Trace(err) + } retryCnt++ if (s.maxRetryCnt != unlimitedRetryCnt) && (retryCnt >= s.maxRetryCnt) { return errors.Trace(err) @@ -564,6 +556,7 @@ func (s *session) ExecutePreparedStmt(stmtID uint32, args ...interface{}) (ast.R } err = PrepareTxnCtx(s) if err != nil { + s.RollbackTxn() return nil, errors.Trace(err) } st := executor.CompileExecutePreparedStmt(s, stmtID, args...) @@ -617,6 +610,25 @@ func (s *session) GetTxn(forceNew bool) (kv.Transaction, error) { return s.txn, nil } +func (s *session) Txn() kv.Transaction { + return s.txn +} + +func (s *session) NewTxn() error { + if s.txn != nil && s.txn.Valid() { + err := s.doCommitWithRetry() + if err != nil { + return errors.Trace(err) + } + } + txn, err := s.store.Begin() + if err != nil { + return errors.Trace(err) + } + s.txn = txn + return nil +} + func (s *session) SetValue(key fmt.Stringer, value interface{}) { s.values[key] = value } @@ -635,7 +647,7 @@ func (s *session) Close() error { return s.RollbackTxn() } -// GetSessionVars implements the context.Context interface +// GetSessionVars implements the context.Context interface. func (s *session) GetSessionVars() *variable.SessionVars { return s.sessionVars } diff --git a/session_test.go b/session_test.go index 7e8231fab6947..e0ffc78b51e97 100644 --- a/session_test.go +++ b/session_test.go @@ -276,18 +276,16 @@ func (s *testSessionSuite) TestAutoIncrementID(c *C) { c.Assert(err, IsNil) } -func checkTxn(c *C, se Session, stmt string, expect uint16) { +func checkTxn(c *C, se Session, stmt string, expectStatus uint16) { mustExecSQL(c, se, stmt) - if expect == 0 { - c.Assert(se.(*session).txn, IsNil) - return + if expectStatus != 0 { + c.Assert(se.(*session).txn.Valid(), IsTrue) } - c.Assert(se.(*session).txn, NotNil) } -func checkAutocommit(c *C, se Session, expect uint16) { +func checkAutocommit(c *C, se Session, expectStatus uint16) { ret := se.(*session).sessionVars.Status & mysql.ServerStatusAutocommit - c.Assert(ret, Equals, expect) + c.Assert(ret, Equals, expectStatus) } // See https://dev.mysql.com/doc/internals/en/status-flags.html @@ -326,10 +324,10 @@ func (s *testSessionSuite) TestAutocommit(c *C) { c.Assert(err, IsNil) } -func checkInTrans(c *C, se Session, stmt string, expect uint16) { - checkTxn(c, se, stmt, expect) +func checkInTrans(c *C, se Session, stmt string, expectStatus uint16) { + checkTxn(c, se, stmt, expectStatus) ret := se.(*session).sessionVars.Status & mysql.ServerStatusInTrans - c.Assert(ret, Equals, expect) + c.Assert(ret, Equals, expectStatus) } // See https://dev.mysql.com/doc/internals/en/status-flags.html diff --git a/table/tables/tables.go b/table/tables/tables.go index 0cf7b7b957c40..c775c9c022ed2 100644 --- a/table/tables/tables.go +++ b/table/tables/tables.go @@ -168,19 +168,6 @@ func (t *Table) FirstKey() kv.Key { return t.RecordKey(0) } -// Truncate implements table.Table Truncate interface. -func (t *Table) Truncate(ctx context.Context) error { - txn, err := ctx.GetTxn(false) - if err != nil { - return errors.Trace(err) - } - err = util.DelKeyWithPrefix(txn, t.RecordPrefix()) - if err != nil { - return errors.Trace(err) - } - return util.DelKeyWithPrefix(txn, t.IndexPrefix()) -} - // UpdateRecord implements table.Table UpdateRecord interface. func (t *Table) UpdateRecord(ctx context.Context, h int64, oldData []types.Datum, newData []types.Datum, touched map[int]bool) error { // We should check whether this table has on update column which state is write only. @@ -193,12 +180,8 @@ func (t *Table) UpdateRecord(ctx context.Context, h int64, oldData []types.Datum return errors.Trace(err) } - txn, err := ctx.GetTxn(false) - if err != nil { - return errors.Trace(err) - } - - bs := kv.NewBufferStore(txn) + txn := ctx.Txn() + bs := kv.NewBufferStore(ctx.Txn()) // Compose new row t.composeNewData(touched, currentData, oldData) @@ -315,11 +298,7 @@ func (t *Table) AddRecord(ctx context.Context, r []types.Datum) (recordID int64, return 0, errors.Trace(err) } } - txn, err := ctx.GetTxn(false) - if err != nil { - return 0, errors.Trace(err) - } - + txn := ctx.Txn() skipCheck := ctx.GetSessionVars().SkipConstraintCheck if skipCheck { txn.SetOption(kv.SkipCheckForWrite, true) @@ -399,10 +378,7 @@ func (t *Table) genIndexKeyStr(colVals []types.Datum) (string, error) { // Add data into indices. func (t *Table) addIndices(ctx context.Context, recordID int64, r []types.Datum, bs *kv.BufferStore) (int64, error) { - txn, err := ctx.GetTxn(false) - if err != nil { - return 0, errors.Trace(err) - } + txn := ctx.Txn() // Clean up lazy check error environment defer txn.DelOption(kv.PresumeKeyNotExistsError) skipCheck := ctx.GetSessionVars().SkipConstraintCheck @@ -411,7 +387,7 @@ func (t *Table) addIndices(ctx context.Context, recordID int64, r []types.Datum, recordKey := t.RecordKey(recordID) e := kv.ErrKeyExists.FastGen("Duplicate entry '%d' for key 'PRIMARY'", recordID) txn.SetOption(kv.PresumeKeyNotExistsError, e) - _, err = txn.Get(recordKey) + _, err := txn.Get(recordKey) if err == nil { return recordID, errors.Trace(e) } else if !terror.ErrorEqual(err, kv.ErrNotExist) { @@ -451,13 +427,9 @@ func (t *Table) addIndices(ctx context.Context, recordID int64, r []types.Datum, // RowWithCols implements table.Table RowWithCols interface. func (t *Table) RowWithCols(ctx context.Context, h int64, cols []*table.Column) ([]types.Datum, error) { - txn, err := ctx.GetTxn(false) - if err != nil { - return nil, errors.Trace(err) - } // Get raw row data from kv. key := t.RecordKey(h) - value, err := txn.Get(key) + value, err := ctx.Txn().Get(key) if err != nil { return nil, errors.Trace(err) } @@ -597,12 +569,8 @@ func (t *Table) addDeleteBinlog(ctx context.Context, h int64, r []types.Datum) e } func (t *Table) removeRowData(ctx context.Context, h int64) error { - txn, err := ctx.GetTxn(false) - if err != nil { - return errors.Trace(err) - } // Remove row data. - err = txn.Delete([]byte(t.RecordKey(h))) + err := ctx.Txn().Delete([]byte(t.RecordKey(h))) if err != nil { return errors.Trace(err) } @@ -617,11 +585,7 @@ func (t *Table) removeRowIndices(ctx context.Context, h int64, rec []types.Datum // TODO: check this continue } - txn, err := ctx.GetTxn(false) - if err != nil { - return errors.Trace(err) - } - if err = v.Delete(txn, vals, h); err != nil { + if err = v.Delete(ctx.Txn(), vals, h); err != nil { if v.Meta().State != model.StatePublic && terror.ErrorEqual(err, kv.ErrNotExist) { // If the index is not in public state, we may have not created the index, // or already deleted the index, so skip ErrNotExist error. @@ -658,11 +622,7 @@ func (t *Table) buildIndexForRow(rm kv.RetrieverMutator, h int64, vals []types.D // IterRecords implements table.Table IterRecords interface. func (t *Table) IterRecords(ctx context.Context, startKey kv.Key, cols []*table.Column, fn table.RecordIterFunc) error { - txn, err := ctx.GetTxn(false) - if err != nil { - return errors.Trace(err) - } - it, err := txn.Seek(startKey) + it, err := ctx.Txn().Seek(startKey) if err != nil { return errors.Trace(err) } @@ -732,11 +692,7 @@ func (t *Table) RebaseAutoID(newBase int64, isSetStep bool) error { // Seek implements table.Table Seek interface. func (t *Table) Seek(ctx context.Context, h int64) (int64, bool, error) { seekKey := tablecodec.EncodeRowKeyWithHandle(t.ID, h) - txn, err := ctx.GetTxn(false) - if err != nil { - return 0, false, errors.Trace(err) - } - iter, err := txn.Seek(seekKey) + iter, err := ctx.Txn().Seek(seekKey) if !iter.Valid() || !iter.Key().HasPrefix(t.RecordPrefix()) { // No more records in the table, skip to the end. return 0, false, nil diff --git a/table/tables/tables_test.go b/table/tables/tables_test.go index 3c5f5d7cc85ae..fc23310430bc9 100644 --- a/table/tables/tables_test.go +++ b/table/tables/tables_test.go @@ -57,6 +57,7 @@ func (ts *testSuite) TestBasic(c *C) { _, err := ts.se.Execute("CREATE TABLE test.t (a int primary key auto_increment, b varchar(255) unique)") c.Assert(err, IsNil) ctx := ts.se.(context.Context) + c.Assert(ctx.NewTxn(), IsNil) dom := sessionctx.GetDomain(ctx) tb, err := dom.InfoSchema().TableByName(model.NewCIStr("test"), model.NewCIStr("t")) c.Assert(err, IsNil) @@ -117,18 +118,13 @@ func (ts *testSuite) TestBasic(c *C) { _, err = tb.AddRecord(ctx, types.MakeDatums(1, "abc")) c.Assert(err, IsNil) c.Assert(indexCnt(), Greater, 0) - c.Assert(ctx.CommitTxn(), IsNil) _, err = ts.se.Execute("drop table test.t") c.Assert(err, IsNil) } func countEntriesWithPrefix(ctx context.Context, prefix []byte) (int, error) { - txn, err := ctx.GetTxn(false) - if err != nil { - return 0, err - } cnt := 0 - err = util.ScanMetaWithPrefix(txn, prefix, func(k kv.Key, v []byte) bool { + err := util.ScanMetaWithPrefix(ctx.Txn(), prefix, func(k kv.Key, v []byte) bool { cnt++ return true }) @@ -198,12 +194,12 @@ func (ts *testSuite) TestUniqueIndexMultipleNullEntries(c *C) { autoid, err := tb.AllocAutoID() c.Assert(err, IsNil) c.Assert(autoid, Greater, int64(0)) - + c.Assert(ctx.NewTxn(), IsNil) _, err = tb.AddRecord(ctx, types.MakeDatums(1, nil)) c.Assert(err, IsNil) _, err = tb.AddRecord(ctx, types.MakeDatums(2, nil)) c.Assert(err, IsNil) - ctx.RollbackTxn() + c.Assert(ctx.Txn().Rollback(), IsNil) _, err = ts.se.Execute("drop table test.t") c.Assert(err, IsNil) } @@ -257,13 +253,14 @@ func (ts *testSuite) TestUnsignedPK(c *C) { dom := sessionctx.GetDomain(ctx) tb, err := dom.InfoSchema().TableByName(model.NewCIStr("test"), model.NewCIStr("tPK")) c.Assert(err, IsNil) - + c.Assert(ctx.NewTxn(), IsNil) rid, err := tb.AddRecord(ctx, types.MakeDatums(1, "abc")) c.Assert(err, IsNil) row, err := tb.Row(ctx, rid) c.Assert(err, IsNil) c.Assert(len(row), Equals, 2) c.Assert(row[0].Kind(), Equals, types.KindUint64) + c.Assert(ctx.Txn().Commit(), IsNil) } func (ts *testSuite) TestIterRecords(c *C) { @@ -273,6 +270,7 @@ func (ts *testSuite) TestIterRecords(c *C) { _, err = ts.se.Execute("INSERT test.tIter VALUES (1, 2), (2, NULL)") c.Assert(err, IsNil) ctx := ts.se.(context.Context) + c.Assert(ctx.NewTxn(), IsNil) dom := sessionctx.GetDomain(ctx) tb, err := dom.InfoSchema().TableByName(model.NewCIStr("test"), model.NewCIStr("tIter")) c.Assert(err, IsNil) @@ -284,4 +282,5 @@ func (ts *testSuite) TestIterRecords(c *C) { }) c.Assert(err, IsNil) c.Assert(totalCount, Equals, 2) + c.Assert(ctx.Txn().Commit(), IsNil) } diff --git a/tidb.go b/tidb.go index 0ba3be0170be7..e4fd8c93c4345 100644 --- a/tidb.go +++ b/tidb.go @@ -185,16 +185,16 @@ func PrepareTxnCtx(ctx context.Context) error { func runStmt(ctx context.Context, s ast.Statement) (ast.RecordSet, error) { var err error var rs ast.RecordSet + se := ctx.(*session) rs, err = s.Exec(ctx) // All the history should be added here. - se := ctx.(*session) getHistory(ctx).add(0, s) if !se.sessionVars.InTxn() { if err != nil { log.Info("RollbackTxn for ddl/autocommit error.") - ctx.RollbackTxn() + se.RollbackTxn() } else { - err = ctx.CommitTxn() + err = se.CommitTxn() } } return rs, errors.Trace(err) diff --git a/util/mock/context.go b/util/mock/context.go index e55456675df9e..ae31f43072300 100644 --- a/util/mock/context.go +++ b/util/mock/context.go @@ -58,52 +58,9 @@ func (c *Context) GetSessionVars() *variable.SessionVars { return c.sessionVars } -// GetTxn implements context.Context GetTxn interface. -func (c *Context) GetTxn(forceNew bool) (kv.Transaction, error) { - c.mux.Lock() - defer c.mux.Unlock() - if c.Store == nil { - return nil, nil - } - - var err error - if c.txn == nil { - c.txn, err = c.Store.Begin() - return c.txn, err - } - if forceNew { - err = c.CommitTxn() - if err != nil { - return nil, errors.Trace(err) - } - c.txn, err = c.Store.Begin() - return c.txn, err - } - - return c.txn, nil -} - -func (c *Context) finishTxn(rollback bool) error { - if c.txn == nil { - return nil - } - defer func() { c.txn = nil }() - - if rollback { - return c.txn.Rollback() - } - - return c.txn.Commit() -} - -// CommitTxn implements context.Context CommitTxn interface. -func (c *Context) CommitTxn() error { - return c.finishTxn(false) -} - -// RollbackTxn implements context.Context RollbackTxn interface. -func (c *Context) RollbackTxn() error { - return c.finishTxn(true) +// Txn implements context.Context Txn interface. +func (c *Context) Txn() kv.Transaction { + return c.txn } // GetClient implements context.Context GetClient interface. @@ -133,6 +90,25 @@ func (c *Context) SetGlobalSysVar(ctx context.Context, name string, value string return nil } +// NewTxn implements the context.Context interface. +func (c *Context) NewTxn() error { + if c.Store == nil { + return errors.New("store is not set") + } + if c.txn != nil && c.txn.Valid() { + err := c.txn.Commit() + if err != nil { + return errors.Trace(err) + } + } + txn, err := c.Store.Begin() + if err != nil { + return errors.Trace(err) + } + c.txn = txn + return nil +} + // NewContext creates a new mocked context.Context. func NewContext() *Context { return &Context{ diff --git a/util/mock/mock_test.go b/util/mock/mock_test.go index d3decbbcff8db..b5708e8d639ae 100644 --- a/util/mock/mock_test.go +++ b/util/mock/mock_test.go @@ -49,10 +49,4 @@ func (s *testMockSuite) TestContext(c *C) { ctx.ClearValue(contextKey) v = ctx.Value(contextKey) c.Assert(v, IsNil) - - _, err := ctx.GetTxn(false) - c.Assert(err, IsNil) - - err = ctx.CommitTxn() - c.Assert(err, IsNil) }