Skip to content

Commit

Permalink
Add BeginTx for parity with sql.DB.BeginTx (go-gorm#2227)
Browse files Browse the repository at this point in the history
  • Loading branch information
tylerstillwater authored and jinzhu committed Jun 11, 2019
1 parent cf9b85e commit fec06da
Show file tree
Hide file tree
Showing 3 changed files with 48 additions and 3 deletions.
6 changes: 5 additions & 1 deletion interface.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
package gorm

import "database/sql"
import (
"context"
"database/sql"
)

// SQLCommon is the minimal database connection functionality gorm requires. Implemented by *sql.DB.
type SQLCommon interface {
Expand All @@ -12,6 +15,7 @@ type SQLCommon interface {

type sqlDb interface {
Begin() (*sql.Tx, error)
BeginTx(ctx context.Context, opts *sql.TxOptions) (*sql.Tx, error)
}

type sqlTx interface {
Expand Down
10 changes: 8 additions & 2 deletions main.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package gorm

import (
"context"
"database/sql"
"errors"
"fmt"
Expand Down Expand Up @@ -503,11 +504,16 @@ func (s *DB) Debug() *DB {
return s.clone().LogMode(true)
}

// Begin begin a transaction
// Begin begins a transaction
func (s *DB) Begin() *DB {
return s.BeginTx(context.Background(), &sql.TxOptions{})
}

// BeginTX begins a transaction with options
func (s *DB) BeginTx(ctx context.Context, opts *sql.TxOptions) *DB {
c := s.clone()
if db, ok := c.db.(sqlDb); ok && db != nil {
tx, err := db.Begin()
tx, err := db.BeginTx(ctx, opts)
c.db = interface{}(tx).(SQLCommon)

c.dialect.SetDB(c.db)
Expand Down
35 changes: 35 additions & 0 deletions main_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package gorm_test

import (
"context"
"database/sql"
"database/sql/driver"
"fmt"
Expand Down Expand Up @@ -471,6 +472,40 @@ func TestTransaction_NoErrorOnRollbackAfterCommit(t *testing.T) {
}
}

func TestTransactionReadonly(t *testing.T) {
dialect := os.Getenv("GORM_DIALECT")
if dialect == "" {
dialect = "sqlite"
}
switch dialect {
case "mssql", "sqlite":
t.Skipf("%s does not support readonly transactions\n", dialect)
}

tx := DB.Begin()
u := User{Name: "transcation"}
if err := tx.Save(&u).Error; err != nil {
t.Errorf("No error should raise")
}
tx.Commit()

tx = DB.BeginTx(context.Background(), &sql.TxOptions{ReadOnly: true})
if err := tx.First(&User{}, "name = ?", "transcation").Error; err != nil {
t.Errorf("Should find saved record")
}

if sqlTx, ok := tx.CommonDB().(*sql.Tx); !ok || sqlTx == nil {
t.Errorf("Should return the underlying sql.Tx")
}

u = User{Name: "transcation-2"}
if err := tx.Save(&u).Error; err == nil {
t.Errorf("Error should have been raised in a readonly transaction")
}

tx.Rollback()
}

func TestRow(t *testing.T) {
user1 := User{Name: "RowUser1", Age: 1, Birthday: parseTime("2000-1-1")}
user2 := User{Name: "RowUser2", Age: 10, Birthday: parseTime("2010-1-1")}
Expand Down

0 comments on commit fec06da

Please sign in to comment.