Skip to content

Commit

Permalink
[+] add ExpectBeginTx() and BeginTxFunc() method
Browse files Browse the repository at this point in the history
  • Loading branch information
pashagolub committed Jun 1, 2021
1 parent 8a9c8c2 commit 0d96e83
Show file tree
Hide file tree
Showing 2 changed files with 59 additions and 24 deletions.
41 changes: 21 additions & 20 deletions expectations.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,13 +31,13 @@ func (e *commonExpectation) fulfilled() bool {
return e.triggered
}

// ExpectedClose is used to manage *sql.DB.Close expectation
// returned by *Sqlmock.ExpectClose.
// ExpectedClose is used to manage pgx.Close expectation
// returned by pgxmock.ExpectClose.
type ExpectedClose struct {
commonExpectation
}

// WillReturnError allows to set an error for *sql.DB.Close action
// WillReturnError allows to set an error for pgx.Close action
func (e *ExpectedClose) WillReturnError(err error) *ExpectedClose {
e.err = err
return e
Expand All @@ -52,14 +52,15 @@ func (e *ExpectedClose) String() string {
return msg
}

// ExpectedBegin is used to manage *sql.DB.Begin expectation
// returned by *Sqlmock.ExpectBegin.
// ExpectedBegin is used to manage *pgx.Begin expectation
// returned by pgxmock.ExpectBegin.
type ExpectedBegin struct {
commonExpectation
delay time.Duration
opts pgx.TxOptions
}

// WillReturnError allows to set an error for *sql.DB.Begin action
// WillReturnError allows to set an error for pgx.Begin action
func (e *ExpectedBegin) WillReturnError(err error) *ExpectedBegin {
e.err = err
return e
Expand All @@ -81,13 +82,13 @@ func (e *ExpectedBegin) WillDelayFor(duration time.Duration) *ExpectedBegin {
return e
}

// ExpectedCommit is used to manage *sql.Tx.Commit expectation
// returned by *Sqlmock.ExpectCommit.
// ExpectedCommit is used to manage pgx.Tx.Commit expectation
// returned by pgxmock.ExpectCommit.
type ExpectedCommit struct {
commonExpectation
}

// WillReturnError allows to set an error for *sql.Tx.Close action
// WillReturnError allows to set an error for pgx.Tx.Close action
func (e *ExpectedCommit) WillReturnError(err error) *ExpectedCommit {
e.err = err
return e
Expand All @@ -102,13 +103,13 @@ func (e *ExpectedCommit) String() string {
return msg
}

// ExpectedRollback is used to manage *sql.Tx.Rollback expectation
// returned by *Sqlmock.ExpectRollback.
// ExpectedRollback is used to manage pgx.Tx.Rollback expectation
// returned by pgxmock.ExpectRollback.
type ExpectedRollback struct {
commonExpectation
}

// WillReturnError allows to set an error for *sql.Tx.Rollback action
// WillReturnError allows to set an error for pgx.Tx.Rollback action
func (e *ExpectedRollback) WillReturnError(err error) *ExpectedRollback {
e.err = err
return e
Expand All @@ -125,7 +126,7 @@ func (e *ExpectedRollback) String() string {

// ExpectedQuery is used to manage *pgx.Conn.Query, *pgx.Conn.QueryRow, *pgx.Tx.Query,
// *pgx.Tx.QueryRow, *pgx.Stmt.Query or *pgx.Stmt.QueryRow expectations.
// Returned by *Sqlmock.ExpectQuery.
// Returned by pgxmock.ExpectQuery.
type ExpectedQuery struct {
queryBasedExpectation
rows pgx.Rows
Expand Down Expand Up @@ -187,8 +188,8 @@ func (e *ExpectedQuery) String() string {
return msg
}

// ExpectedExec is used to manage *sql.DB.Exec, *sql.Tx.Exec or *sql.Stmt.Exec expectations.
// Returned by *Sqlmock.ExpectExec.
// ExpectedExec is used to manage pgx.Exec, pgx.Tx.Exec or pgx.Stmt.Exec expectations.
// Returned by pgxmock.ExpectExec.
type ExpectedExec struct {
queryBasedExpectation
result pgconn.CommandTag
Expand Down Expand Up @@ -253,8 +254,8 @@ func (e *ExpectedExec) WillReturnResult(result pgconn.CommandTag) *ExpectedExec
return e
}

// ExpectedPrepare is used to manage *sql.DB.Prepare or *sql.Tx.Prepare expectations.
// Returned by *Sqlmock.ExpectPrepare.
// ExpectedPrepare is used to manage pgx.Prepare or pgx.Tx.Prepare expectations.
// Returned by pgxmock.ExpectPrepare.
type ExpectedPrepare struct {
commonExpectation
mock *pgxmock
Expand All @@ -266,7 +267,7 @@ type ExpectedPrepare struct {
delay time.Duration
}

// WillReturnError allows to set an error for the expected *sql.DB.Prepare or *sql.Tx.Prepare action.
// WillReturnError allows to set an error for the expected pgx.Prepare or pgx.Tx.Prepare action.
func (e *ExpectedPrepare) WillReturnError(err error) *ExpectedPrepare {
e.err = err
return e
Expand Down Expand Up @@ -338,8 +339,8 @@ type queryBasedExpectation struct {
args []interface{}
}

// ExpectedPing is used to manage *sql.DB.Ping expectations.
// Returned by *Sqlmock.ExpectPing.
// ExpectedPing is used to manage pgx.Ping expectations.
// Returned by pgxmock.ExpectPing.
type ExpectedPing struct {
commonExpectation
delay time.Duration
Expand Down
42 changes: 38 additions & 4 deletions pgxmock.go
Original file line number Diff line number Diff line change
Expand Up @@ -198,6 +198,12 @@ func (c *pgxmock) ExpectBegin() *ExpectedBegin {
return e
}

func (c *pgxmock) ExpectBeginTx(txOptions pgx.TxOptions) *ExpectedBegin {
e := &ExpectedBegin{opts: txOptions}
c.expected = append(c.expected, e)
return e
}

func (c *pgxmock) ExpectExec(expectedSQL string) *ExpectedExec {
e := &ExpectedExec{}
e.expectSQL = expectedSQL
Expand Down Expand Up @@ -444,8 +450,30 @@ func (c *pgxmock) BeginFunc(ctx context.Context, f func(pgx.Tx) error) (err erro
return savepoint.Commit(ctx)
}

func (c *pgxmock) Begin(ctx context.Context) (pgx.Tx, error) {
ex, err := c.begin()
func (c *pgxmock) BeginTxFunc(ctx context.Context, txOptions pgx.TxOptions, f func(pgx.Tx) error) (err error) {
var tx pgx.Tx
tx, err = c.BeginTx(ctx, txOptions)
if err != nil {
return err
}
defer func() {
rollbackErr := tx.Rollback(ctx)
if !(rollbackErr == nil || errors.Is(rollbackErr, pgx.ErrTxClosed)) {
err = rollbackErr
}
}()

fErr := f(tx)
if fErr != nil {
_ = tx.Rollback(ctx) // ignore rollback error as there is already an error to return
return fErr
}

return tx.Commit(ctx)
}

func (c *pgxmock) BeginTx(ctx context.Context, txOptions pgx.TxOptions) (pgx.Tx, error) {
ex, err := c.begin(txOptions)
if ex != nil {
time.Sleep(ex.delay)
}
Expand All @@ -456,7 +484,11 @@ func (c *pgxmock) Begin(ctx context.Context) (pgx.Tx, error) {
return c, nil
}

func (c *pgxmock) begin() (*ExpectedBegin, error) {
func (c *pgxmock) Begin(ctx context.Context) (pgx.Tx, error) {
return c.BeginTx(ctx, pgx.TxOptions{})
}

func (c *pgxmock) begin(txOptions pgx.TxOptions) (*ExpectedBegin, error) {
var expected *ExpectedBegin
var ok bool
var fulfilled int
Expand Down Expand Up @@ -484,7 +516,9 @@ func (c *pgxmock) begin() (*ExpectedBegin, error) {
}
return nil, fmt.Errorf(msg)
}

if expected.opts != txOptions {
return nil, fmt.Errorf("Begin: call with transaction options '%v' was not expected, expected name is '%v'", txOptions, expected.opts)
}
expected.triggered = true
expected.Unlock()

Expand Down

0 comments on commit 0d96e83

Please sign in to comment.