From 0d96e8398578792ff83abababb5bf909dca57436 Mon Sep 17 00:00:00 2001 From: Pavlo Golub Date: Tue, 1 Jun 2021 14:33:24 +0300 Subject: [PATCH] [+] add ExpectBeginTx() and BeginTxFunc() method --- expectations.go | 41 +++++++++++++++++++++-------------------- pgxmock.go | 42 ++++++++++++++++++++++++++++++++++++++---- 2 files changed, 59 insertions(+), 24 deletions(-) diff --git a/expectations.go b/expectations.go index beec4d6..f027a27 100644 --- a/expectations.go +++ b/expectations.go @@ -31,13 +31,13 @@ func (e *commonExpectation) fulfilled() bool { return e.triggered } -// ExpectedClose is used to manage *sql.DB.Close expectation -// returned by *Sqlmock.ExpectClose. +// ExpectedClose is used to manage pgx.Close expectation +// returned by pgxmock.ExpectClose. type ExpectedClose struct { commonExpectation } -// WillReturnError allows to set an error for *sql.DB.Close action +// WillReturnError allows to set an error for pgx.Close action func (e *ExpectedClose) WillReturnError(err error) *ExpectedClose { e.err = err return e @@ -52,14 +52,15 @@ func (e *ExpectedClose) String() string { return msg } -// ExpectedBegin is used to manage *sql.DB.Begin expectation -// returned by *Sqlmock.ExpectBegin. +// ExpectedBegin is used to manage *pgx.Begin expectation +// returned by pgxmock.ExpectBegin. type ExpectedBegin struct { commonExpectation delay time.Duration + opts pgx.TxOptions } -// WillReturnError allows to set an error for *sql.DB.Begin action +// WillReturnError allows to set an error for pgx.Begin action func (e *ExpectedBegin) WillReturnError(err error) *ExpectedBegin { e.err = err return e @@ -81,13 +82,13 @@ func (e *ExpectedBegin) WillDelayFor(duration time.Duration) *ExpectedBegin { return e } -// ExpectedCommit is used to manage *sql.Tx.Commit expectation -// returned by *Sqlmock.ExpectCommit. +// ExpectedCommit is used to manage pgx.Tx.Commit expectation +// returned by pgxmock.ExpectCommit. type ExpectedCommit struct { commonExpectation } -// WillReturnError allows to set an error for *sql.Tx.Close action +// WillReturnError allows to set an error for pgx.Tx.Close action func (e *ExpectedCommit) WillReturnError(err error) *ExpectedCommit { e.err = err return e @@ -102,13 +103,13 @@ func (e *ExpectedCommit) String() string { return msg } -// ExpectedRollback is used to manage *sql.Tx.Rollback expectation -// returned by *Sqlmock.ExpectRollback. +// ExpectedRollback is used to manage pgx.Tx.Rollback expectation +// returned by pgxmock.ExpectRollback. type ExpectedRollback struct { commonExpectation } -// WillReturnError allows to set an error for *sql.Tx.Rollback action +// WillReturnError allows to set an error for pgx.Tx.Rollback action func (e *ExpectedRollback) WillReturnError(err error) *ExpectedRollback { e.err = err return e @@ -125,7 +126,7 @@ func (e *ExpectedRollback) String() string { // ExpectedQuery is used to manage *pgx.Conn.Query, *pgx.Conn.QueryRow, *pgx.Tx.Query, // *pgx.Tx.QueryRow, *pgx.Stmt.Query or *pgx.Stmt.QueryRow expectations. -// Returned by *Sqlmock.ExpectQuery. +// Returned by pgxmock.ExpectQuery. type ExpectedQuery struct { queryBasedExpectation rows pgx.Rows @@ -187,8 +188,8 @@ func (e *ExpectedQuery) String() string { return msg } -// ExpectedExec is used to manage *sql.DB.Exec, *sql.Tx.Exec or *sql.Stmt.Exec expectations. -// Returned by *Sqlmock.ExpectExec. +// ExpectedExec is used to manage pgx.Exec, pgx.Tx.Exec or pgx.Stmt.Exec expectations. +// Returned by pgxmock.ExpectExec. type ExpectedExec struct { queryBasedExpectation result pgconn.CommandTag @@ -253,8 +254,8 @@ func (e *ExpectedExec) WillReturnResult(result pgconn.CommandTag) *ExpectedExec return e } -// ExpectedPrepare is used to manage *sql.DB.Prepare or *sql.Tx.Prepare expectations. -// Returned by *Sqlmock.ExpectPrepare. +// ExpectedPrepare is used to manage pgx.Prepare or pgx.Tx.Prepare expectations. +// Returned by pgxmock.ExpectPrepare. type ExpectedPrepare struct { commonExpectation mock *pgxmock @@ -266,7 +267,7 @@ type ExpectedPrepare struct { delay time.Duration } -// WillReturnError allows to set an error for the expected *sql.DB.Prepare or *sql.Tx.Prepare action. +// WillReturnError allows to set an error for the expected pgx.Prepare or pgx.Tx.Prepare action. func (e *ExpectedPrepare) WillReturnError(err error) *ExpectedPrepare { e.err = err return e @@ -338,8 +339,8 @@ type queryBasedExpectation struct { args []interface{} } -// ExpectedPing is used to manage *sql.DB.Ping expectations. -// Returned by *Sqlmock.ExpectPing. +// ExpectedPing is used to manage pgx.Ping expectations. +// Returned by pgxmock.ExpectPing. type ExpectedPing struct { commonExpectation delay time.Duration diff --git a/pgxmock.go b/pgxmock.go index 8f87119..d9fb886 100644 --- a/pgxmock.go +++ b/pgxmock.go @@ -198,6 +198,12 @@ func (c *pgxmock) ExpectBegin() *ExpectedBegin { return e } +func (c *pgxmock) ExpectBeginTx(txOptions pgx.TxOptions) *ExpectedBegin { + e := &ExpectedBegin{opts: txOptions} + c.expected = append(c.expected, e) + return e +} + func (c *pgxmock) ExpectExec(expectedSQL string) *ExpectedExec { e := &ExpectedExec{} e.expectSQL = expectedSQL @@ -444,8 +450,30 @@ func (c *pgxmock) BeginFunc(ctx context.Context, f func(pgx.Tx) error) (err erro return savepoint.Commit(ctx) } -func (c *pgxmock) Begin(ctx context.Context) (pgx.Tx, error) { - ex, err := c.begin() +func (c *pgxmock) BeginTxFunc(ctx context.Context, txOptions pgx.TxOptions, f func(pgx.Tx) error) (err error) { + var tx pgx.Tx + tx, err = c.BeginTx(ctx, txOptions) + if err != nil { + return err + } + defer func() { + rollbackErr := tx.Rollback(ctx) + if !(rollbackErr == nil || errors.Is(rollbackErr, pgx.ErrTxClosed)) { + err = rollbackErr + } + }() + + fErr := f(tx) + if fErr != nil { + _ = tx.Rollback(ctx) // ignore rollback error as there is already an error to return + return fErr + } + + return tx.Commit(ctx) +} + +func (c *pgxmock) BeginTx(ctx context.Context, txOptions pgx.TxOptions) (pgx.Tx, error) { + ex, err := c.begin(txOptions) if ex != nil { time.Sleep(ex.delay) } @@ -456,7 +484,11 @@ func (c *pgxmock) Begin(ctx context.Context) (pgx.Tx, error) { return c, nil } -func (c *pgxmock) begin() (*ExpectedBegin, error) { +func (c *pgxmock) Begin(ctx context.Context) (pgx.Tx, error) { + return c.BeginTx(ctx, pgx.TxOptions{}) +} + +func (c *pgxmock) begin(txOptions pgx.TxOptions) (*ExpectedBegin, error) { var expected *ExpectedBegin var ok bool var fulfilled int @@ -484,7 +516,9 @@ func (c *pgxmock) begin() (*ExpectedBegin, error) { } return nil, fmt.Errorf(msg) } - + if expected.opts != txOptions { + return nil, fmt.Errorf("Begin: call with transaction options '%v' was not expected, expected name is '%v'", txOptions, expected.opts) + } expected.triggered = true expected.Unlock()