diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000..6ee063c --- /dev/null +++ b/LICENSE @@ -0,0 +1,28 @@ +The three clause BSD license (http://en.wikipedia.org/wiki/BSD_licenses) + +Copyright (c) 2013-2019, DATA-DOG team +All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are met: + +* Redistributions of source code must retain the above copyright notice, this + list of conditions and the following disclaimer. + +* Redistributions in binary form must reproduce the above copyright notice, + this list of conditions and the following disclaimer in the documentation + and/or other materials provided with the distribution. + +* The name DataDog.lt may not be used to endorse or promote products + derived from this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +DISCLAIMED. IN NO EVENT SHALL MICHAEL BOSTOCK BE LIABLE FOR ANY DIRECT, +INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, +BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY +OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING +NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, +EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. diff --git a/README.md b/README.md new file mode 100644 index 0000000..d4151f4 --- /dev/null +++ b/README.md @@ -0,0 +1,265 @@ +[![Build Status](https://travis-ci.org/DATA-DOG/go-sqlmock.svg)](https://travis-ci.org/DATA-DOG/go-sqlmock) +[![GoDoc](https://godoc.org/github.com/DATA-DOG/go-sqlmock?status.svg)](https://godoc.org/github.com/DATA-DOG/go-sqlmock) +[![Go Report Card](https://goreportcard.com/badge/github.com/DATA-DOG/go-sqlmock)](https://goreportcard.com/report/github.com/DATA-DOG/go-sqlmock) +[![codecov.io](https://codecov.io/github/DATA-DOG/go-sqlmock/branch/master/graph/badge.svg)](https://codecov.io/github/DATA-DOG/go-sqlmock) + +# Sql driver mock for Golang + +**sqlmock** is a mock library implementing [sql/driver](https://godoc.org/database/sql/driver). Which has one and only +purpose - to simulate any **sql** driver behavior in tests, without needing a real database connection. It helps to +maintain correct **TDD** workflow. + +- this library is now complete and stable. (you may not find new changes for this reason) +- supports concurrency and multiple connections. +- supports **go1.8** Context related feature mocking and Named sql parameters. +- does not require any modifications to your source code. +- the driver allows to mock any sql driver method behavior. +- has strict by default expectation order matching. +- has no third party dependencies. + +**NOTE:** in **v1.2.0** **sqlmock.Rows** has changed to struct from interface, if you were using any type references to that +interface, you will need to switch it to a pointer struct type. Also, **sqlmock.Rows** were used to implement **driver.Rows** +interface, which was not required or useful for mocking and was removed. Hope it will not cause issues. + +## Looking for maintainers + +I do not have much spare time for this library and willing to transfer the repository ownership +to person or an organization motivated to maintain it. Open up a conversation if you are interested. See #230. + +## Install + + go get github.com/DATA-DOG/go-sqlmock + +## Documentation and Examples + +Visit [godoc](http://godoc.org/github.com/DATA-DOG/go-sqlmock) for general examples and public api reference. +See **.travis.yml** for supported **go** versions. +Different use case, is to functionally test with a real database - [go-txdb](https://github.com/DATA-DOG/go-txdb) +all database related actions are isolated within a single transaction so the database can remain in the same state. + +See implementation examples: + +- [blog API server](https://github.com/DATA-DOG/go-sqlmock/tree/master/examples/blog) +- [the same orders example](https://github.com/DATA-DOG/go-sqlmock/tree/master/examples/orders) + +### Something you may want to test, assuming you use the [go-mysql-driver](https://github.com/go-sql-driver/mysql) + +``` go +package main + +import ( + "database/sql" + + _ "github.com/go-sql-driver/mysql" +) + +func recordStats(db *sql.DB, userID, productID int64) (err error) { + tx, err = db.Begin() + if err != nil { + return + } + + defer func() { + switch err { + case nil: + err = tx.Commit() + default: + tx.Rollback() + } + }() + + if _, err = tx.Exec("UPDATE products SET views = views + 1"); err != nil { + return + } + if _, err = tx.Exec("INSERT INTO product_viewers (user_id, product_id) VALUES (?, ?)", userID, productID); err != nil { + return + } + return +} + +func main() { + // @NOTE: the real connection is not required for tests + db, err := sql.Open("mysql", "root@/blog") + if err != nil { + panic(err) + } + defer db.Close() + + if err = recordStats(db, 1 /*some user id*/, 5 /*some product id*/); err != nil { + panic(err) + } +} +``` + +### Tests with sqlmock + +``` go +package main + +import ( + "fmt" + "testing" + + "github.com/DATA-DOG/go-sqlmock" +) + +// a successful case +func TestShouldUpdateStats(t *testing.T) { + db, mock, err := sqlmock.New() + if err != nil { + t.Fatalf("an error '%s' was not expected when opening a stub database connection", err) + } + defer db.Close() + + mock.ExpectBegin() + mock.ExpectExec("UPDATE products").WillReturnResult(sqlmock.NewResult(1, 1)) + mock.ExpectExec("INSERT INTO product_viewers").WithArgs(2, 3).WillReturnResult(sqlmock.NewResult(1, 1)) + mock.ExpectCommit() + + // now we execute our method + if err = recordStats(db, 2, 3); err != nil { + t.Errorf("error was not expected while updating stats: %s", err) + } + + // we make sure that all expectations were met + if err := mock.ExpectationsWereMet(); err != nil { + t.Errorf("there were unfulfilled expectations: %s", err) + } +} + +// a failing test case +func TestShouldRollbackStatUpdatesOnFailure(t *testing.T) { + db, mock, err := sqlmock.New() + if err != nil { + t.Fatalf("an error '%s' was not expected when opening a stub database connection", err) + } + defer db.Close() + + mock.ExpectBegin() + mock.ExpectExec("UPDATE products").WillReturnResult(sqlmock.NewResult(1, 1)) + mock.ExpectExec("INSERT INTO product_viewers"). + WithArgs(2, 3). + WillReturnError(fmt.Errorf("some error")) + mock.ExpectRollback() + + // now we execute our method + if err = recordStats(db, 2, 3); err == nil { + t.Errorf("was expecting an error, but there was none") + } + + // we make sure that all expectations were met + if err := mock.ExpectationsWereMet(); err != nil { + t.Errorf("there were unfulfilled expectations: %s", err) + } +} +``` + +## Customize SQL query matching + +There were plenty of requests from users regarding SQL query string validation or different matching option. +We have now implemented the `QueryMatcher` interface, which can be passed through an option when calling +`sqlmock.New` or `sqlmock.NewWithDSN`. + +This now allows to include some library, which would allow for example to parse and validate `mysql` SQL AST. +And create a custom QueryMatcher in order to validate SQL in sophisticated ways. + +By default, **sqlmock** is preserving backward compatibility and default query matcher is `sqlmock.QueryMatcherRegexp` +which uses expected SQL string as a regular expression to match incoming query string. There is an equality matcher: +`QueryMatcherEqual` which will do a full case sensitive match. + +In order to customize the QueryMatcher, use the following: + +``` go + db, mock, err := sqlmock.New(sqlmock.QueryMatcherOption(sqlmock.QueryMatcherEqual)) +``` + +The query matcher can be fully customized based on user needs. **sqlmock** will not +provide a standard sql parsing matchers, since various drivers may not follow the same SQL standard. + +## Matching arguments like time.Time + +There may be arguments which are of `struct` type and cannot be compared easily by value like `time.Time`. In this case +**sqlmock** provides an [Argument](https://godoc.org/github.com/DATA-DOG/go-sqlmock#Argument) interface which +can be used in more sophisticated matching. Here is a simple example of time argument matching: + +``` go +type AnyTime struct{} + +// Match satisfies sqlmock.Argument interface +func (a AnyTime) Match(v driver.Value) bool { + _, ok := v.(time.Time) + return ok +} + +func TestAnyTimeArgument(t *testing.T) { + t.Parallel() + db, mock, err := New() + if err != nil { + t.Errorf("an error '%s' was not expected when opening a stub database connection", err) + } + defer db.Close() + + mock.ExpectExec("INSERT INTO users"). + WithArgs("john", AnyTime{}). + WillReturnResult(NewResult(1, 1)) + + _, err = db.Exec("INSERT INTO users(name, created_at) VALUES (?, ?)", "john", time.Now()) + if err != nil { + t.Errorf("error '%s' was not expected, while inserting a row", err) + } + + if err := mock.ExpectationsWereMet(); err != nil { + t.Errorf("there were unfulfilled expectations: %s", err) + } +} +``` + +It only asserts that argument is of `time.Time` type. + +## Run tests + + go test -race + +## Change Log + +- **2019-04-06** - added functionality to mock a sql MetaData request +- **2019-02-13** - added `go.mod` removed the references and suggestions using `gopkg.in`. +- **2018-12-11** - added expectation of Rows to be closed, while mocking expected query. +- **2018-12-11** - introduced an option to provide **QueryMatcher** in order to customize SQL query matching. +- **2017-09-01** - it is now possible to expect that prepared statement will be closed, + using **ExpectedPrepare.WillBeClosed**. +- **2017-02-09** - implemented support for **go1.8** features. **Rows** interface was changed to struct + but contains all methods as before and should maintain backwards compatibility. **ExpectedQuery.WillReturnRows** may now + accept multiple row sets. +- **2016-11-02** - `db.Prepare()` was not validating expected prepare SQL + query. It should still be validated even if Exec or Query is not + executed on that prepared statement. +- **2016-02-23** - added **sqlmock.AnyArg()** function to provide any kind + of argument matcher. +- **2016-02-23** - convert expected arguments to driver.Value as natural + driver does, the change may affect time.Time comparison and will be + stricter. See [issue](https://github.com/DATA-DOG/go-sqlmock/issues/31). +- **2015-08-27** - **v1** api change, concurrency support, all known issues fixed. +- **2014-08-16** instead of **panic** during reflect type mismatch when comparing query arguments - now return error +- **2014-08-14** added **sqlmock.NewErrorResult** which gives an option to return driver.Result with errors for +interface methods, see [issue](https://github.com/DATA-DOG/go-sqlmock/issues/5) +- **2014-05-29** allow to match arguments in more sophisticated ways, by providing an **sqlmock.Argument** interface +- **2014-04-21** introduce **sqlmock.New()** to open a mock database connection for tests. This method +calls sql.DB.Ping to ensure that connection is open, see [issue](https://github.com/DATA-DOG/go-sqlmock/issues/4). +This way on Close it will surely assert if all expectations are met, even if database was not triggered at all. +The old way is still available, but it is advisable to call db.Ping manually before asserting with db.Close. +- **2014-02-14** RowsFromCSVString is now a part of Rows interface named as FromCSVString. +It has changed to allow more ways to construct rows and to easily extend this API in future. +See [issue 1](https://github.com/DATA-DOG/go-sqlmock/issues/1) +**RowsFromCSVString** is deprecated and will be removed in future + +## Contributions + +Feel free to open a pull request. Note, if you wish to contribute an extension to public (exported methods or types) - +please open an issue before, to discuss whether these changes can be accepted. All backward incompatible changes are +and will be treated cautiously + +## License + +The [three clause BSD license](http://en.wikipedia.org/wiki/BSD_licenses) + diff --git a/argument.go b/argument.go new file mode 100644 index 0000000..7727481 --- /dev/null +++ b/argument.go @@ -0,0 +1,24 @@ +package sqlmock + +import "database/sql/driver" + +// Argument interface allows to match +// any argument in specific way when used with +// ExpectedQuery and ExpectedExec expectations. +type Argument interface { + Match(driver.Value) bool +} + +// AnyArg will return an Argument which can +// match any kind of arguments. +// +// Useful for time.Time or similar kinds of arguments. +func AnyArg() Argument { + return anyArgument{} +} + +type anyArgument struct{} + +func (a anyArgument) Match(_ driver.Value) bool { + return true +} diff --git a/argument_test.go b/argument_test.go new file mode 100644 index 0000000..0e0d13b --- /dev/null +++ b/argument_test.go @@ -0,0 +1,58 @@ +package sqlmock + +import ( + "database/sql/driver" + "testing" + "time" +) + +type AnyTime struct{} + +// Match satisfies sqlmock.Argument interface +func (a AnyTime) Match(v driver.Value) bool { + _, ok := v.(time.Time) + return ok +} + +func TestAnyTimeArgument(t *testing.T) { + t.Parallel() + db, mock, err := New() + if err != nil { + t.Errorf("an error '%s' was not expected when opening a stub database connection", err) + } + defer db.Close() + + mock.ExpectExec("INSERT INTO users"). + WithArgs("john", AnyTime{}). + WillReturnResult(NewResult(1, 1)) + + _, err = db.Exec("INSERT INTO users(name, created_at) VALUES (?, ?)", "john", time.Now()) + if err != nil { + t.Errorf("error '%s' was not expected, while inserting a row", err) + } + + if err := mock.ExpectationsWereMet(); err != nil { + t.Errorf("there were unfulfilled expectations: %s", err) + } +} + +func TestByteSliceArgument(t *testing.T) { + t.Parallel() + db, mock, err := New() + if err != nil { + t.Errorf("an error '%s' was not expected when opening a stub database connection", err) + } + defer db.Close() + + username := []byte("user") + mock.ExpectExec("INSERT INTO users").WithArgs(username).WillReturnResult(NewResult(1, 1)) + + _, err = db.Exec("INSERT INTO users(username) VALUES (?)", username) + if err != nil { + t.Errorf("error '%s' was not expected, while inserting a row", err) + } + + if err := mock.ExpectationsWereMet(); err != nil { + t.Errorf("there were unfulfilled expectations: %s", err) + } +} diff --git a/column.go b/column.go new file mode 100644 index 0000000..e418d2e --- /dev/null +++ b/column.go @@ -0,0 +1,77 @@ +package sqlmock + +import "reflect" + +// Column is a mocked column Metadata for rows.ColumnTypes() +type Column struct { + name string + dbType string + nullable bool + nullableOk bool + length int64 + lengthOk bool + precision int64 + scale int64 + psOk bool + scanType reflect.Type +} + +func (c *Column) Name() string { + return c.name +} + +func (c *Column) DbType() string { + return c.dbType +} + +func (c *Column) IsNullable() (bool, bool) { + return c.nullable, c.nullableOk +} + +func (c *Column) Length() (int64, bool) { + return c.length, c.lengthOk +} + +func (c *Column) PrecisionScale() (int64, int64, bool) { + return c.precision, c.scale, c.psOk +} + +func (c *Column) ScanType() reflect.Type { + return c.scanType +} + +// NewColumn returns a Column with specified name +func NewColumn(name string) *Column { + return &Column{ + name: name, + } +} + +// Nullable returns the column with nullable metadata set +func (c *Column) Nullable(nullable bool) *Column { + c.nullable = nullable + c.nullableOk = true + return c +} + +// OfType returns the column with type metadata set +func (c *Column) OfType(dbType string, sampleValue interface{}) *Column { + c.dbType = dbType + c.scanType = reflect.TypeOf(sampleValue) + return c +} + +// WithLength returns the column with length metadata set. +func (c *Column) WithLength(length int64) *Column { + c.length = length + c.lengthOk = true + return c +} + +// WithPrecisionAndScale returns the column with precision and scale metadata set. +func (c *Column) WithPrecisionAndScale(precision, scale int64) *Column { + c.precision = precision + c.scale = scale + c.psOk = true + return c +} diff --git a/column_test.go b/column_test.go new file mode 100644 index 0000000..0311216 --- /dev/null +++ b/column_test.go @@ -0,0 +1,63 @@ +package sqlmock + +import ( + "reflect" + "testing" + "time" +) + +func TestColumn(t *testing.T) { + now, _ := time.Parse(time.RFC3339, "2020-06-20T22:08:41Z") + column1 := NewColumn("test").OfType("VARCHAR", "").Nullable(true).WithLength(100) + column2 := NewColumn("number").OfType("DECIMAL", float64(0.0)).Nullable(false).WithPrecisionAndScale(10, 4) + column3 := NewColumn("when").OfType("TIMESTAMP", now) + + if column1.ScanType().Kind() != reflect.String { + t.Errorf("string scanType mismatch: %v", column1.ScanType()) + } + if column2.ScanType().Kind() != reflect.Float64 { + t.Errorf("float scanType mismatch: %v", column2.ScanType()) + } + if column3.ScanType() != reflect.TypeOf(time.Time{}) { + t.Errorf("time scanType mismatch: %v", column3.ScanType()) + } + + nullable, ok := column1.IsNullable() + if !nullable || !ok { + t.Errorf("'test' column should be nullable") + } + nullable, ok = column2.IsNullable() + if nullable || !ok { + t.Errorf("'number' column should not be nullable") + } + nullable, ok = column3.IsNullable() + if ok { + t.Errorf("'when' column nullability should be unknown") + } + + length, ok := column1.Length() + if length != 100 || !ok { + t.Errorf("'test' column wrong length") + } + length, ok = column2.Length() + if ok { + t.Errorf("'number' column is not of variable length type") + } + length, ok = column3.Length() + if ok { + t.Errorf("'when' column is not of variable length type") + } + + _, _, ok = column1.PrecisionScale() + if ok { + t.Errorf("'test' column not applicable") + } + precision, scale, ok := column2.PrecisionScale() + if precision != 10 || scale != 4 || !ok { + t.Errorf("'number' column not applicable") + } + _, _, ok = column3.PrecisionScale() + if ok { + t.Errorf("'when' column not applicable") + } +} diff --git a/driver.go b/driver.go new file mode 100644 index 0000000..802f8fb --- /dev/null +++ b/driver.go @@ -0,0 +1,81 @@ +package sqlmock + +import ( + "database/sql" + "database/sql/driver" + "fmt" + "sync" +) + +var pool *mockDriver + +func init() { + pool = &mockDriver{ + conns: make(map[string]*sqlmock), + } + sql.Register("sqlmock", pool) +} + +type mockDriver struct { + sync.Mutex + counter int + conns map[string]*sqlmock +} + +func (d *mockDriver) Open(dsn string) (driver.Conn, error) { + d.Lock() + defer d.Unlock() + + c, ok := d.conns[dsn] + if !ok { + return c, fmt.Errorf("expected a connection to be available, but it is not") + } + + c.opened++ + return c, nil +} + +// New creates sqlmock database connection and a mock to manage expectations. +// Accepts options, like ValueConverterOption, to use a ValueConverter from +// a specific driver. +// Pings db so that all expectations could be +// asserted. +func New(options ...func(*sqlmock) error) (*sql.DB, Sqlmock, error) { + pool.Lock() + dsn := fmt.Sprintf("sqlmock_db_%d", pool.counter) + pool.counter++ + + smock := &sqlmock{dsn: dsn, drv: pool, ordered: true} + pool.conns[dsn] = smock + pool.Unlock() + + return smock.open(options) +} + +// NewWithDSN creates sqlmock database connection with a specific DSN +// and a mock to manage expectations. +// Accepts options, like ValueConverterOption, to use a ValueConverter from +// a specific driver. +// Pings db so that all expectations could be asserted. +// +// This method is introduced because of sql abstraction +// libraries, which do not provide a way to initialize +// with sql.DB instance. For example GORM library. +// +// Note, it will error if attempted to create with an +// already used dsn +// +// It is not recommended to use this method, unless you +// really need it and there is no other way around. +func NewWithDSN(dsn string, options ...func(*sqlmock) error) (*sql.DB, Sqlmock, error) { + pool.Lock() + if _, ok := pool.conns[dsn]; ok { + pool.Unlock() + return nil, nil, fmt.Errorf("cannot create a new mock database with the same dsn: %s", dsn) + } + smock := &sqlmock{dsn: dsn, drv: pool, ordered: true} + pool.conns[dsn] = smock + pool.Unlock() + + return smock.open(options) +} diff --git a/driver_test.go b/driver_test.go new file mode 100644 index 0000000..bbd7293 --- /dev/null +++ b/driver_test.go @@ -0,0 +1,132 @@ +package sqlmock + +import ( + "database/sql/driver" + "errors" + "fmt" + "testing" +) + +type void struct{} + +func (void) Print(...interface{}) {} + +type converter struct{} + +func (c *converter) ConvertValue(v interface{}) (driver.Value, error) { + return nil, errors.New("converter disabled") +} + +func ExampleNew() { + db, mock, err := New() + if err != nil { + fmt.Println("expected no error, but got:", err) + return + } + defer db.Close() + // now we can expect operations performed on db + mock.ExpectBegin().WillReturnError(fmt.Errorf("an error will occur on db.Begin() call")) +} + +func TestShouldOpenConnectionIssue15(t *testing.T) { + db, mock, err := New() + if err != nil { + t.Errorf("expected no error, but got: %s", err) + } + if len(pool.conns) != 1 { + t.Errorf("expected 1 connection in pool, but there is: %d", len(pool.conns)) + } + + smock, _ := mock.(*sqlmock) + if smock.opened != 1 { + t.Errorf("expected 1 connection on mock to be opened, but there is: %d", smock.opened) + } + + // defer so the rows gets closed first + defer func() { + if smock.opened != 0 { + t.Errorf("expected no connections on mock to be opened, but there is: %d", smock.opened) + } + }() + + mock.ExpectQuery("SELECT").WillReturnRows(NewRows([]string{"one", "two"}).AddRow("val1", "val2")) + rows, err := db.Query("SELECT") + if err != nil { + t.Errorf("unexpected error: %s", err) + } + defer rows.Close() + + mock.ExpectExec("UPDATE").WillReturnResult(NewResult(1, 1)) + if _, err = db.Exec("UPDATE"); err != nil { + t.Errorf("unexpected error: %s", err) + } + + // now there should be two connections open + if smock.opened != 2 { + t.Errorf("expected 2 connection on mock to be opened, but there is: %d", smock.opened) + } + + mock.ExpectClose() + if err = db.Close(); err != nil { + t.Errorf("expected no error on close, but got: %s", err) + } + + // one is still reserved for rows + if smock.opened != 1 { + t.Errorf("expected 1 connection on mock to be still reserved for rows, but there is: %d", smock.opened) + } +} + +func TestTwoOpenConnectionsOnTheSameDSN(t *testing.T) { + db, mock, err := New() + if err != nil { + t.Errorf("expected no error, but got: %s", err) + } + db2, mock2, err := New() + if err != nil { + t.Errorf("expected no error, but got: %s", err) + } + if len(pool.conns) != 2 { + t.Errorf("expected 2 connection in pool, but there is: %d", len(pool.conns)) + } + + if db == db2 { + t.Errorf("expected not the same database instance, but it is the same") + } + if mock == mock2 { + t.Errorf("expected not the same mock instance, but it is the same") + } +} + +func TestWithOptions(t *testing.T) { + c := &converter{} + _, mock, err := New(ValueConverterOption(c)) + if err != nil { + t.Errorf("expected no error, but got: %s", err) + } + smock, _ := mock.(*sqlmock) + if smock.converter.(*converter) != c { + t.Errorf("expected a custom converter to be set") + } +} + +func TestWrongDSN(t *testing.T) { + t.Parallel() + db, _, _ := New() + defer db.Close() + if _, err := db.Driver().Open("wrong_dsn"); err == nil { + t.Error("expected error on Open") + } +} + +func TestNewDSN(t *testing.T) { + if _, _, err := NewWithDSN("sqlmock_db_99"); err != nil { + t.Errorf("expected no error on NewWithDSN, but got: %s", err) + } +} + +func TestDuplicateNewDSN(t *testing.T) { + if _, _, err := NewWithDSN("sqlmock_db_1"); err == nil { + t.Error("expected error on NewWithDSN") + } +} diff --git a/examples/.vscode/launch.json b/examples/.vscode/launch.json new file mode 100644 index 0000000..5c7247b --- /dev/null +++ b/examples/.vscode/launch.json @@ -0,0 +1,7 @@ +{ + // Use IntelliSense to learn about possible attributes. + // Hover to view descriptions of existing attributes. + // For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387 + "version": "0.2.0", + "configurations": [] +} \ No newline at end of file diff --git a/examples/basic/basic.go b/examples/basic/basic.go new file mode 100644 index 0000000..0fbf98d --- /dev/null +++ b/examples/basic/basic.go @@ -0,0 +1,40 @@ +package main + +import "database/sql" + +func recordStats(db *sql.DB, userID, productID int64) (err error) { + tx, err := db.Begin() + if err != nil { + return + } + + defer func() { + switch err { + case nil: + err = tx.Commit() + default: + tx.Rollback() + } + }() + + if _, err = tx.Exec("UPDATE products SET views = views + 1"); err != nil { + return + } + if _, err = tx.Exec("INSERT INTO product_viewers (user_id, product_id) VALUES (?, ?)", userID, productID); err != nil { + return + } + return +} + +func main() { + // @NOTE: the real connection is not required for tests + db, err := sql.Open("mysql", "root@/blog") + if err != nil { + panic(err) + } + defer db.Close() + + if err = recordStats(db, 1 /*some user id*/, 5 /*some product id*/); err != nil { + panic(err) + } +} diff --git a/examples/basic/basic_test.go b/examples/basic/basic_test.go new file mode 100644 index 0000000..e9153a5 --- /dev/null +++ b/examples/basic/basic_test.go @@ -0,0 +1,58 @@ +package main + +import ( + "fmt" + "testing" + + "github.com/DATA-DOG/go-sqlmock" +) + +// a successful case +func TestShouldUpdateStats(t *testing.T) { + db, mock, err := sqlmock.New() + if err != nil { + t.Fatalf("an error '%s' was not expected when opening a stub database connection", err) + } + defer db.Close() + + mock.ExpectBegin() + mock.ExpectExec("UPDATE products").WillReturnResult(sqlmock.NewResult(1, 1)) + mock.ExpectExec("INSERT INTO product_viewers").WithArgs(2, 3).WillReturnResult(sqlmock.NewResult(1, 1)) + mock.ExpectCommit() + + // now we execute our method + if err = recordStats(db, 2, 3); err != nil { + t.Errorf("error was not expected while updating stats: %s", err) + } + + // we make sure that all expectations were met + if err := mock.ExpectationsWereMet(); err != nil { + t.Errorf("there were unfulfilled expectations: %s", err) + } +} + +// a failing test case +func TestShouldRollbackStatUpdatesOnFailure(t *testing.T) { + db, mock, err := sqlmock.New() + if err != nil { + t.Fatalf("an error '%s' was not expected when opening a stub database connection", err) + } + defer db.Close() + + mock.ExpectBegin() + mock.ExpectExec("UPDATE products").WillReturnResult(sqlmock.NewResult(1, 1)) + mock.ExpectExec("INSERT INTO product_viewers"). + WithArgs(2, 3). + WillReturnError(fmt.Errorf("some error")) + mock.ExpectRollback() + + // now we execute our method + if err = recordStats(db, 2, 3); err == nil { + t.Errorf("was expecting an error, but there was none") + } + + // we make sure that all expectations were met + if err := mock.ExpectationsWereMet(); err != nil { + t.Errorf("there were unfulfilled expectations: %s", err) + } +} diff --git a/examples/blog/blog.go b/examples/blog/blog.go new file mode 100644 index 0000000..c4aec06 --- /dev/null +++ b/examples/blog/blog.go @@ -0,0 +1,81 @@ +package main + +import ( + "database/sql" + "encoding/json" + "net/http" +) + +type api struct { + db *sql.DB +} + +type post struct { + ID int + Title string + Body string +} + +func (a *api) posts(w http.ResponseWriter, r *http.Request) { + rows, err := a.db.Query("SELECT id, title, body FROM posts") + if err != nil { + a.fail(w, "failed to fetch posts: "+err.Error(), 500) + return + } + defer rows.Close() + + var posts []*post + for rows.Next() { + p := &post{} + if err := rows.Scan(&p.ID, &p.Title, &p.Body); err != nil { + a.fail(w, "failed to scan post: "+err.Error(), 500) + return + } + posts = append(posts, p) + } + if rows.Err() != nil { + a.fail(w, "failed to read all posts: "+rows.Err().Error(), 500) + return + } + + data := struct { + Posts []*post + }{posts} + + a.ok(w, data) +} + +func main() { + // @NOTE: the real connection is not required for tests + db, err := sql.Open("mysql", "root@/blog") + if err != nil { + panic(err) + } + app := &api{db: db} + http.HandleFunc("/posts", app.posts) + http.ListenAndServe(":8080", nil) +} + +func (a *api) fail(w http.ResponseWriter, msg string, status int) { + w.Header().Set("Content-Type", "application/json") + + data := struct { + Error string + }{Error: msg} + + resp, _ := json.Marshal(data) + w.WriteHeader(status) + w.Write(resp) +} + +func (a *api) ok(w http.ResponseWriter, data interface{}) { + w.Header().Set("Content-Type", "application/json") + + resp, err := json.Marshal(data) + if err != nil { + w.WriteHeader(http.StatusInternalServerError) + a.fail(w, "oops something evil has happened", 500) + return + } + w.Write(resp) +} diff --git a/examples/blog/blog_test.go b/examples/blog/blog_test.go new file mode 100644 index 0000000..2442067 --- /dev/null +++ b/examples/blog/blog_test.go @@ -0,0 +1,102 @@ +package main + +import ( + "bytes" + "encoding/json" + "fmt" + "net/http" + "net/http/httptest" + "testing" + + "github.com/DATA-DOG/go-sqlmock" +) + +func (a *api) assertJSON(actual []byte, data interface{}, t *testing.T) { + expected, err := json.Marshal(data) + if err != nil { + t.Fatalf("an error '%s' was not expected when marshaling expected json data", err) + } + + if bytes.Compare(expected, actual) != 0 { + t.Errorf("the expected json: %s is different from actual %s", expected, actual) + } +} + +func TestShouldGetPosts(t *testing.T) { + db, mock, err := sqlmock.New() + if err != nil { + t.Fatalf("an error '%s' was not expected when opening a stub database connection", err) + } + defer db.Close() + + // create app with mocked db, request and response to test + app := &api{db} + req, err := http.NewRequest("GET", "http://localhost/posts", nil) + if err != nil { + t.Fatalf("an error '%s' was not expected while creating request", err) + } + w := httptest.NewRecorder() + + // before we actually execute our api function, we need to expect required DB actions + rows := sqlmock.NewRows([]string{"id", "title", "body"}). + AddRow(1, "post 1", "hello"). + AddRow(2, "post 2", "world") + + mock.ExpectQuery("^SELECT (.+) FROM posts$").WillReturnRows(rows) + + // now we execute our request + app.posts(w, req) + + if w.Code != 200 { + t.Fatalf("expected status code to be 200, but got: %d", w.Code) + } + + data := struct { + Posts []*post + }{Posts: []*post{ + {ID: 1, Title: "post 1", Body: "hello"}, + {ID: 2, Title: "post 2", Body: "world"}, + }} + app.assertJSON(w.Body.Bytes(), data, t) + + // we make sure that all expectations were met + if err := mock.ExpectationsWereMet(); err != nil { + t.Errorf("there were unfulfilled expectations: %s", err) + } +} + +func TestShouldRespondWithErrorOnFailure(t *testing.T) { + db, mock, err := sqlmock.New() + if err != nil { + t.Fatalf("an error '%s' was not expected when opening a stub database connection", err) + } + defer db.Close() + + // create app with mocked db, request and response to test + app := &api{db} + req, err := http.NewRequest("GET", "http://localhost/posts", nil) + if err != nil { + t.Fatalf("an error '%s' was not expected while creating request", err) + } + w := httptest.NewRecorder() + + // before we actually execute our api function, we need to expect required DB actions + mock.ExpectQuery("^SELECT (.+) FROM posts$").WillReturnError(fmt.Errorf("some error")) + + // now we execute our request + app.posts(w, req) + + if w.Code != 500 { + t.Fatalf("expected status code to be 500, but got: %d", w.Code) + } + + data := struct { + Error string + }{"failed to fetch posts: some error"} + app.assertJSON(w.Body.Bytes(), data, t) + + // we make sure that all expectations were met + if err := mock.ExpectationsWereMet(); err != nil { + t.Errorf("there were unfulfilled expectations: %s", err) + } +} diff --git a/examples/doc.go b/examples/doc.go new file mode 100644 index 0000000..c7842af --- /dev/null +++ b/examples/doc.go @@ -0,0 +1 @@ +package examples diff --git a/examples/orders/orders.go b/examples/orders/orders.go new file mode 100644 index 0000000..fb7e47e --- /dev/null +++ b/examples/orders/orders.go @@ -0,0 +1,121 @@ +package main + +import ( + "database/sql" + "fmt" + "log" + + "github.com/kisielk/sqlstruct" +) + +const ORDER_PENDING = 0 +const ORDER_CANCELLED = 1 + +type User struct { + Id int `sql:"id"` + Username string `sql:"username"` + Balance float64 `sql:"balance"` +} + +type Order struct { + Id int `sql:"id"` + Value float64 `sql:"value"` + ReservedFee float64 `sql:"reserved_fee"` + Status int `sql:"status"` +} + +func cancelOrder(id int, db *sql.DB) (err error) { + tx, err := db.Begin() + if err != nil { + return + } + + var order Order + var user User + sql := fmt.Sprintf(` +SELECT %s, %s +FROM orders AS o +INNER JOIN users AS u ON o.buyer_id = u.id +WHERE o.id = ? +FOR UPDATE`, + sqlstruct.ColumnsAliased(order, "o"), + sqlstruct.ColumnsAliased(user, "u")) + + // fetch order to cancel + rows, err := tx.Query(sql, id) + if err != nil { + tx.Rollback() + return + } + + defer rows.Close() + // no rows, nothing to do + if !rows.Next() { + tx.Rollback() + return + } + + // read order + err = sqlstruct.ScanAliased(&order, rows, "o") + if err != nil { + tx.Rollback() + return + } + + // ensure order status + if order.Status != ORDER_PENDING { + tx.Rollback() + return + } + + // read user + err = sqlstruct.ScanAliased(&user, rows, "u") + if err != nil { + tx.Rollback() + return + } + rows.Close() // manually close before other prepared statements + + // refund order value + sql = "UPDATE users SET balance = balance + ? WHERE id = ?" + refundStmt, err := tx.Prepare(sql) + if err != nil { + tx.Rollback() + return + } + defer refundStmt.Close() + _, err = refundStmt.Exec(order.Value+order.ReservedFee, user.Id) + if err != nil { + tx.Rollback() + return + } + + // update order status + order.Status = ORDER_CANCELLED + sql = "UPDATE orders SET status = ?, updated = NOW() WHERE id = ?" + orderUpdStmt, err := tx.Prepare(sql) + if err != nil { + tx.Rollback() + return + } + defer orderUpdStmt.Close() + _, err = orderUpdStmt.Exec(order.Status, order.Id) + if err != nil { + tx.Rollback() + return + } + return tx.Commit() +} + +func main() { + // @NOTE: the real connection is not required for tests + db, err := sql.Open("mysql", "root:@/orders") + if err != nil { + log.Fatal(err) + } + defer db.Close() + err = cancelOrder(1, db) + if err != nil { + log.Fatal(err) + } +} diff --git a/examples/orders/orders_test.go b/examples/orders/orders_test.go new file mode 100644 index 0000000..1dd10b1 --- /dev/null +++ b/examples/orders/orders_test.go @@ -0,0 +1,108 @@ +package main + +import ( + "fmt" + "testing" + + "github.com/DATA-DOG/go-sqlmock" +) + +// will test that order with a different status, cannot be cancelled +func TestShouldNotCancelOrderWithNonPendingStatus(t *testing.T) { + // open database stub + db, mock, err := sqlmock.New() + if err != nil { + t.Errorf("An error '%s' was not expected when opening a stub database connection", err) + } + defer db.Close() + + // columns are prefixed with "o" since we used sqlstruct to generate them + columns := []string{"o_id", "o_status"} + // expect transaction begin + mock.ExpectBegin() + // expect query to fetch order and user, match it with regexp + mock.ExpectQuery("SELECT (.+) FROM orders AS o INNER JOIN users AS u (.+) FOR UPDATE"). + WithArgs(1). + WillReturnRows(sqlmock.NewRows(columns).FromCSVString("1,1")) + // expect transaction rollback, since order status is "cancelled" + mock.ExpectRollback() + + // run the cancel order function + err = cancelOrder(1, db) + if err != nil { + t.Errorf("Expected no error, but got %s instead", err) + } + // we make sure that all expectations were met + if err := mock.ExpectationsWereMet(); err != nil { + t.Errorf("there were unfulfilled expectations: %s", err) + } +} + +// will test order cancellation +func TestShouldRefundUserWhenOrderIsCancelled(t *testing.T) { + // open database stub + db, mock, err := sqlmock.New() + if err != nil { + t.Errorf("An error '%s' was not expected when opening a stub database connection", err) + } + defer db.Close() + + // columns are prefixed with "o" since we used sqlstruct to generate them + columns := []string{"o_id", "o_status", "o_value", "o_reserved_fee", "u_id", "u_balance"} + // expect transaction begin + mock.ExpectBegin() + // expect query to fetch order and user, match it with regexp + mock.ExpectQuery("SELECT (.+) FROM orders AS o INNER JOIN users AS u (.+) FOR UPDATE"). + WithArgs(1). + WillReturnRows(sqlmock.NewRows(columns).AddRow(1, 0, 25.75, 3.25, 2, 10.00)) + // expect user balance update + mock.ExpectPrepare("UPDATE users SET balance").ExpectExec(). + WithArgs(25.75+3.25, 2). // refund amount, user id + WillReturnResult(sqlmock.NewResult(0, 1)) // no insert id, 1 affected row + // expect order status update + mock.ExpectPrepare("UPDATE orders SET status").ExpectExec(). + WithArgs(ORDER_CANCELLED, 1). // status, id + WillReturnResult(sqlmock.NewResult(0, 1)) // no insert id, 1 affected row + // expect a transaction commit + mock.ExpectCommit() + + // run the cancel order function + err = cancelOrder(1, db) + if err != nil { + t.Errorf("Expected no error, but got %s instead", err) + } + // we make sure that all expectations were met + if err := mock.ExpectationsWereMet(); err != nil { + t.Errorf("there were unfulfilled expectations: %s", err) + } +} + +// will test order cancellation +func TestShouldRollbackOnError(t *testing.T) { + // open database stub + db, mock, err := sqlmock.New() + if err != nil { + t.Errorf("An error '%s' was not expected when opening a stub database connection", err) + } + defer db.Close() + + // expect transaction begin + mock.ExpectBegin() + // expect query to fetch order and user, match it with regexp + mock.ExpectQuery("SELECT (.+) FROM orders AS o INNER JOIN users AS u (.+) FOR UPDATE"). + WithArgs(1). + WillReturnError(fmt.Errorf("Some error")) + // should rollback since error was returned from query execution + mock.ExpectRollback() + + // run the cancel order function + err = cancelOrder(1, db) + // error should return back + if err == nil { + t.Error("Expected error, but got none") + } + // we make sure that all expectations were met + if err := mock.ExpectationsWereMet(); err != nil { + t.Errorf("there were unfulfilled expectations: %s", err) + } +} diff --git a/expectations.go b/expectations.go new file mode 100644 index 0000000..5c82c7b --- /dev/null +++ b/expectations.go @@ -0,0 +1,369 @@ +package sqlmock + +import ( + "database/sql/driver" + "fmt" + "strings" + "sync" + "time" +) + +// an expectation interface +type expectation interface { + fulfilled() bool + Lock() + Unlock() + String() string +} + +// common expectation struct +// satisfies the expectation interface +type commonExpectation struct { + sync.Mutex + triggered bool + err error +} + +func (e *commonExpectation) fulfilled() bool { + return e.triggered +} + +// ExpectedClose is used to manage *sql.DB.Close expectation +// returned by *Sqlmock.ExpectClose. +type ExpectedClose struct { + commonExpectation +} + +// WillReturnError allows to set an error for *sql.DB.Close action +func (e *ExpectedClose) WillReturnError(err error) *ExpectedClose { + e.err = err + return e +} + +// String returns string representation +func (e *ExpectedClose) String() string { + msg := "ExpectedClose => expecting database Close" + if e.err != nil { + msg += fmt.Sprintf(", which should return error: %s", e.err) + } + return msg +} + +// ExpectedBegin is used to manage *sql.DB.Begin expectation +// returned by *Sqlmock.ExpectBegin. +type ExpectedBegin struct { + commonExpectation + delay time.Duration +} + +// WillReturnError allows to set an error for *sql.DB.Begin action +func (e *ExpectedBegin) WillReturnError(err error) *ExpectedBegin { + e.err = err + return e +} + +// String returns string representation +func (e *ExpectedBegin) String() string { + msg := "ExpectedBegin => expecting database transaction Begin" + if e.err != nil { + msg += fmt.Sprintf(", which should return error: %s", e.err) + } + return msg +} + +// WillDelayFor allows to specify duration for which it will delay +// result. May be used together with Context +func (e *ExpectedBegin) WillDelayFor(duration time.Duration) *ExpectedBegin { + e.delay = duration + return e +} + +// ExpectedCommit is used to manage *sql.Tx.Commit expectation +// returned by *Sqlmock.ExpectCommit. +type ExpectedCommit struct { + commonExpectation +} + +// WillReturnError allows to set an error for *sql.Tx.Close action +func (e *ExpectedCommit) WillReturnError(err error) *ExpectedCommit { + e.err = err + return e +} + +// String returns string representation +func (e *ExpectedCommit) String() string { + msg := "ExpectedCommit => expecting transaction Commit" + if e.err != nil { + msg += fmt.Sprintf(", which should return error: %s", e.err) + } + return msg +} + +// ExpectedRollback is used to manage *sql.Tx.Rollback expectation +// returned by *Sqlmock.ExpectRollback. +type ExpectedRollback struct { + commonExpectation +} + +// WillReturnError allows to set an error for *sql.Tx.Rollback action +func (e *ExpectedRollback) WillReturnError(err error) *ExpectedRollback { + e.err = err + return e +} + +// String returns string representation +func (e *ExpectedRollback) String() string { + msg := "ExpectedRollback => expecting transaction Rollback" + if e.err != nil { + msg += fmt.Sprintf(", which should return error: %s", e.err) + } + return msg +} + +// ExpectedQuery is used to manage *sql.DB.Query, *dql.DB.QueryRow, *sql.Tx.Query, +// *sql.Tx.QueryRow, *sql.Stmt.Query or *sql.Stmt.QueryRow expectations. +// Returned by *Sqlmock.ExpectQuery. +type ExpectedQuery struct { + queryBasedExpectation + rows driver.Rows + delay time.Duration + rowsMustBeClosed bool + rowsWereClosed bool +} + +// WithArgs will match given expected args to actual database query arguments. +// if at least one argument does not match, it will return an error. For specific +// arguments an sqlmock.Argument interface can be used to match an argument. +func (e *ExpectedQuery) WithArgs(args ...driver.Value) *ExpectedQuery { + e.args = args + return e +} + +// RowsWillBeClosed expects this query rows to be closed. +func (e *ExpectedQuery) RowsWillBeClosed() *ExpectedQuery { + e.rowsMustBeClosed = true + return e +} + +// WillReturnError allows to set an error for expected database query +func (e *ExpectedQuery) WillReturnError(err error) *ExpectedQuery { + e.err = err + return e +} + +// WillDelayFor allows to specify duration for which it will delay +// result. May be used together with Context +func (e *ExpectedQuery) WillDelayFor(duration time.Duration) *ExpectedQuery { + e.delay = duration + return e +} + +// String returns string representation +func (e *ExpectedQuery) String() string { + msg := "ExpectedQuery => expecting Query, QueryContext or QueryRow which:" + msg += "\n - matches sql: '" + e.expectSQL + "'" + + if len(e.args) == 0 { + msg += "\n - is without arguments" + } else { + msg += "\n - is with arguments:\n" + for i, arg := range e.args { + msg += fmt.Sprintf(" %d - %+v\n", i, arg) + } + msg = strings.TrimSpace(msg) + } + + if e.rows != nil { + msg += fmt.Sprintf("\n - %s", e.rows) + } + + if e.err != nil { + msg += fmt.Sprintf("\n - should return error: %s", e.err) + } + + return msg +} + +// ExpectedExec is used to manage *sql.DB.Exec, *sql.Tx.Exec or *sql.Stmt.Exec expectations. +// Returned by *Sqlmock.ExpectExec. +type ExpectedExec struct { + queryBasedExpectation + result driver.Result + delay time.Duration +} + +// WithArgs will match given expected args to actual database exec operation arguments. +// if at least one argument does not match, it will return an error. For specific +// arguments an sqlmock.Argument interface can be used to match an argument. +func (e *ExpectedExec) WithArgs(args ...driver.Value) *ExpectedExec { + e.args = args + return e +} + +// WillReturnError allows to set an error for expected database exec action +func (e *ExpectedExec) WillReturnError(err error) *ExpectedExec { + e.err = err + return e +} + +// WillDelayFor allows to specify duration for which it will delay +// result. May be used together with Context +func (e *ExpectedExec) WillDelayFor(duration time.Duration) *ExpectedExec { + e.delay = duration + return e +} + +// String returns string representation +func (e *ExpectedExec) String() string { + msg := "ExpectedExec => expecting Exec or ExecContext which:" + msg += "\n - matches sql: '" + e.expectSQL + "'" + + if len(e.args) == 0 { + msg += "\n - is without arguments" + } else { + msg += "\n - is with arguments:\n" + var margs []string + for i, arg := range e.args { + margs = append(margs, fmt.Sprintf(" %d - %+v", i, arg)) + } + msg += strings.Join(margs, "\n") + } + + if e.result != nil { + res, _ := e.result.(*result) + msg += "\n - should return Result having:" + msg += fmt.Sprintf("\n LastInsertId: %d", res.insertID) + msg += fmt.Sprintf("\n RowsAffected: %d", res.rowsAffected) + if res.err != nil { + msg += fmt.Sprintf("\n Error: %s", res.err) + } + } + + if e.err != nil { + msg += fmt.Sprintf("\n - should return error: %s", e.err) + } + + return msg +} + +// WillReturnResult arranges for an expected Exec() to return a particular +// result, there is sqlmock.NewResult(lastInsertID int64, affectedRows int64) method +// to build a corresponding result. Or if actions needs to be tested against errors +// sqlmock.NewErrorResult(err error) to return a given error. +func (e *ExpectedExec) WillReturnResult(result driver.Result) *ExpectedExec { + e.result = result + return e +} + +// ExpectedPrepare is used to manage *sql.DB.Prepare or *sql.Tx.Prepare expectations. +// Returned by *Sqlmock.ExpectPrepare. +type ExpectedPrepare struct { + commonExpectation + mock *sqlmock + expectSQL string + statement driver.Stmt + closeErr error + mustBeClosed bool + wasClosed bool + delay time.Duration +} + +// WillReturnError allows to set an error for the expected *sql.DB.Prepare or *sql.Tx.Prepare action. +func (e *ExpectedPrepare) WillReturnError(err error) *ExpectedPrepare { + e.err = err + return e +} + +// WillReturnCloseError allows to set an error for this prepared statement Close action +func (e *ExpectedPrepare) WillReturnCloseError(err error) *ExpectedPrepare { + e.closeErr = err + return e +} + +// WillDelayFor allows to specify duration for which it will delay +// result. May be used together with Context +func (e *ExpectedPrepare) WillDelayFor(duration time.Duration) *ExpectedPrepare { + e.delay = duration + return e +} + +// WillBeClosed expects this prepared statement to +// be closed. +func (e *ExpectedPrepare) WillBeClosed() *ExpectedPrepare { + e.mustBeClosed = true + return e +} + +// ExpectQuery allows to expect Query() or QueryRow() on this prepared statement. +// This method is convenient in order to prevent duplicating sql query string matching. +func (e *ExpectedPrepare) ExpectQuery() *ExpectedQuery { + eq := &ExpectedQuery{} + eq.expectSQL = e.expectSQL + eq.converter = e.mock.converter + e.mock.expected = append(e.mock.expected, eq) + return eq +} + +// ExpectExec allows to expect Exec() on this prepared statement. +// This method is convenient in order to prevent duplicating sql query string matching. +func (e *ExpectedPrepare) ExpectExec() *ExpectedExec { + eq := &ExpectedExec{} + eq.expectSQL = e.expectSQL + eq.converter = e.mock.converter + e.mock.expected = append(e.mock.expected, eq) + return eq +} + +// String returns string representation +func (e *ExpectedPrepare) String() string { + msg := "ExpectedPrepare => expecting Prepare statement which:" + msg += "\n - matches sql: '" + e.expectSQL + "'" + + if e.err != nil { + msg += fmt.Sprintf("\n - should return error: %s", e.err) + } + + if e.closeErr != nil { + msg += fmt.Sprintf("\n - should return error on Close: %s", e.closeErr) + } + + return msg +} + +// query based expectation +// adds a query matching logic +type queryBasedExpectation struct { + commonExpectation + expectSQL string + converter driver.ValueConverter + args []driver.Value +} + +// ExpectedPing is used to manage *sql.DB.Ping expectations. +// Returned by *Sqlmock.ExpectPing. +type ExpectedPing struct { + commonExpectation + delay time.Duration +} + +// WillDelayFor allows to specify duration for which it will delay result. May +// be used together with Context. +func (e *ExpectedPing) WillDelayFor(duration time.Duration) *ExpectedPing { + e.delay = duration + return e +} + +// WillReturnError allows to set an error for expected database ping +func (e *ExpectedPing) WillReturnError(err error) *ExpectedPing { + e.err = err + return e +} + +// String returns string representation +func (e *ExpectedPing) String() string { + msg := "ExpectedPing => expecting database Ping" + if e.err != nil { + msg += fmt.Sprintf(", which should return error: %s", e.err) + } + return msg +} diff --git a/expectations_go18.go b/expectations_go18.go new file mode 100644 index 0000000..6b85ce1 --- /dev/null +++ b/expectations_go18.go @@ -0,0 +1,85 @@ +// +build go1.8 + +package sqlmock + +import ( + "database/sql" + "database/sql/driver" + "fmt" + "reflect" +) + +// WillReturnRows specifies the set of resulting rows that will be returned +// by the triggered query +func (e *ExpectedQuery) WillReturnRows(rows ...*Rows) *ExpectedQuery { + defs := 0 + sets := make([]*Rows, len(rows)) + for i, r := range rows { + sets[i] = r + if r.def != nil { + defs++ + } + } + if defs > 0 && defs == len(sets) { + e.rows = &rowSetsWithDefinition{&rowSets{sets: sets, ex: e}} + } else { + e.rows = &rowSets{sets: sets, ex: e} + } + return e +} + +func (e *queryBasedExpectation) argsMatches(args []driver.NamedValue) error { + if nil == e.args { + return nil + } + if len(args) != len(e.args) { + return fmt.Errorf("expected %d, but got %d arguments", len(e.args), len(args)) + } + // @TODO should we assert either all args are named or ordinal? + for k, v := range args { + // custom argument matcher + matcher, ok := e.args[k].(Argument) + if ok { + if !matcher.Match(v.Value) { + return fmt.Errorf("matcher %T could not match %d argument %T - %+v", matcher, k, args[k], args[k]) + } + continue + } + + dval := e.args[k] + if named, isNamed := dval.(sql.NamedArg); isNamed { + dval = named.Value + if v.Name != named.Name { + return fmt.Errorf("named argument %d: name: \"%s\" does not match expected: \"%s\"", k, v.Name, named.Name) + } + } else if k+1 != v.Ordinal { + return fmt.Errorf("argument %d: ordinal position: %d does not match expected: %d", k, k+1, v.Ordinal) + } + + // convert to driver converter + darg, err := e.converter.ConvertValue(dval) + if err != nil { + return fmt.Errorf("could not convert %d argument %T - %+v to driver value: %s", k, e.args[k], e.args[k], err) + } + + if !reflect.DeepEqual(darg, v.Value) { + return fmt.Errorf("argument %d expected [%T - %+v] does not match actual [%T - %+v]", k, darg, darg, v.Value, v.Value) + } + } + return nil +} + +func (e *queryBasedExpectation) attemptArgMatch(args []driver.NamedValue) (err error) { + // catch panic + defer func() { + if e := recover(); e != nil { + _, ok := e.(error) + if !ok { + err = fmt.Errorf(e.(string)) + } + } + }() + + err = e.argsMatches(args) + return +} diff --git a/expectations_go18_test.go b/expectations_go18_test.go new file mode 100644 index 0000000..1974721 --- /dev/null +++ b/expectations_go18_test.go @@ -0,0 +1,174 @@ +// +build go1.8 + +package sqlmock + +import ( + "database/sql" + "database/sql/driver" + "testing" + "time" +) + +func TestQueryExpectationArgComparison(t *testing.T) { + e := &queryBasedExpectation{converter: driver.DefaultParameterConverter} + against := []driver.NamedValue{{Value: int64(5), Ordinal: 1}} + if err := e.argsMatches(against); err != nil { + t.Errorf("arguments should match, since the no expectation was set, but got err: %s", err) + } + + e.args = []driver.Value{5, "str"} + + against = []driver.NamedValue{{Value: int64(5), Ordinal: 1}} + if err := e.argsMatches(against); err == nil { + t.Error("arguments should not match, since the size is not the same") + } + + against = []driver.NamedValue{ + {Value: int64(3), Ordinal: 1}, + {Value: "str", Ordinal: 2}, + } + if err := e.argsMatches(against); err == nil { + t.Error("arguments should not match, since the first argument (int value) is different") + } + + against = []driver.NamedValue{ + {Value: int64(5), Ordinal: 1}, + {Value: "st", Ordinal: 2}, + } + if err := e.argsMatches(against); err == nil { + t.Error("arguments should not match, since the second argument (string value) is different") + } + + against = []driver.NamedValue{ + {Value: int64(5), Ordinal: 1}, + {Value: "str", Ordinal: 2}, + } + if err := e.argsMatches(against); err != nil { + t.Errorf("arguments should match, but it did not: %s", err) + } + + const longForm = "Jan 2, 2006 at 3:04pm (MST)" + tm, _ := time.Parse(longForm, "Feb 3, 2013 at 7:54pm (PST)") + e.args = []driver.Value{5, tm} + + against = []driver.NamedValue{ + {Value: int64(5), Ordinal: 1}, + {Value: tm, Ordinal: 2}, + } + if err := e.argsMatches(against); err != nil { + t.Error("arguments should match, but it did not") + } + + e.args = []driver.Value{5, AnyArg()} + if err := e.argsMatches(against); err != nil { + t.Errorf("arguments should match, but it did not: %s", err) + } +} + +func TestQueryExpectationArgComparisonBool(t *testing.T) { + var e *queryBasedExpectation + + e = &queryBasedExpectation{args: []driver.Value{true}, converter: driver.DefaultParameterConverter} + against := []driver.NamedValue{ + {Value: true, Ordinal: 1}, + } + if err := e.argsMatches(against); err != nil { + t.Error("arguments should match, since arguments are the same") + } + + e = &queryBasedExpectation{args: []driver.Value{false}, converter: driver.DefaultParameterConverter} + against = []driver.NamedValue{ + {Value: false, Ordinal: 1}, + } + if err := e.argsMatches(against); err != nil { + t.Error("arguments should match, since argument are the same") + } + + e = &queryBasedExpectation{args: []driver.Value{true}, converter: driver.DefaultParameterConverter} + against = []driver.NamedValue{ + {Value: false, Ordinal: 1}, + } + if err := e.argsMatches(against); err == nil { + t.Error("arguments should not match, since argument is different") + } + + e = &queryBasedExpectation{args: []driver.Value{false}, converter: driver.DefaultParameterConverter} + against = []driver.NamedValue{ + {Value: true, Ordinal: 1}, + } + if err := e.argsMatches(against); err == nil { + t.Error("arguments should not match, since argument is different") + } +} + +func TestQueryExpectationNamedArgComparison(t *testing.T) { + e := &queryBasedExpectation{converter: driver.DefaultParameterConverter} + against := []driver.NamedValue{{Value: int64(5), Name: "id"}} + if err := e.argsMatches(against); err != nil { + t.Errorf("arguments should match, since the no expectation was set, but got err: %s", err) + } + + e.args = []driver.Value{ + sql.Named("id", 5), + sql.Named("s", "str"), + } + + if err := e.argsMatches(against); err == nil { + t.Error("arguments should not match, since the size is not the same") + } + + against = []driver.NamedValue{ + {Value: int64(5), Name: "id"}, + {Value: "str", Name: "s"}, + } + + if err := e.argsMatches(against); err != nil { + t.Errorf("arguments should have matched, but it did not: %v", err) + } + + against = []driver.NamedValue{ + {Value: int64(5), Name: "id"}, + {Value: "str", Name: "username"}, + } + + if err := e.argsMatches(against); err == nil { + t.Error("arguments matched, but it should have not due to Name") + } + + e.args = []driver.Value{int64(5), "str"} + + against = []driver.NamedValue{ + {Value: int64(5), Ordinal: 0}, + {Value: "str", Ordinal: 1}, + } + + if err := e.argsMatches(against); err == nil { + t.Error("arguments matched, but it should have not due to wrong Ordinal position") + } + + against = []driver.NamedValue{ + {Value: int64(5), Ordinal: 1}, + {Value: "str", Ordinal: 2}, + } + + if err := e.argsMatches(against); err != nil { + t.Errorf("arguments should have matched, but it did not: %v", err) + } +} + +type panicConverter struct { +} + +func (s panicConverter) ConvertValue(v interface{}) (driver.Value, error) { + panic(v) +} + +func Test_queryBasedExpectation_attemptArgMatch(t *testing.T) { + e := &queryBasedExpectation{converter: new(panicConverter), args: []driver.Value{"test"}} + values := []driver.NamedValue{ + {Ordinal: 1, Name: "test", Value: "test"}, + } + if err := e.attemptArgMatch(values); err == nil { + t.Errorf("error expected") + } +} diff --git a/expectations_go19_test.go b/expectations_go19_test.go new file mode 100644 index 0000000..4ea5f04 --- /dev/null +++ b/expectations_go19_test.go @@ -0,0 +1,43 @@ +// +build go1.9 + +package sqlmock + +import ( + "context" + "testing" +) + +func TestCustomValueConverterExec(t *testing.T) { + db, mock, _ := New(ValueConverterOption(CustomConverter{})) + expectedQuery := "INSERT INTO tags \\(name,email,age,hobbies\\) VALUES \\(\\?,\\?,\\?,\\?\\)" + query := "INSERT INTO tags (name,email,age,hobbies) VALUES (?,?,?,?)" + name := "John" + email := "j@jj.j" + age := 12 + hobbies := []string{"soccer", "netflix"} + mock.ExpectBegin() + mock.ExpectPrepare(expectedQuery) + mock.ExpectExec(expectedQuery).WithArgs(name, email, age, hobbies).WillReturnResult(NewResult(1, 1)) + mock.ExpectCommit() + + ctx := context.Background() + tx, e := db.BeginTx(ctx, nil) + if e != nil { + t.Error(e) + return + } + stmt, e := db.PrepareContext(ctx, query) + if e != nil { + t.Error(e) + return + } + _, e = stmt.Exec(name, email, age, hobbies) + if e != nil { + t.Error(e) + return + } + tx.Commit() + if err := mock.ExpectationsWereMet(); err != nil { + t.Error(err) + } +} diff --git a/expectations_test.go b/expectations_test.go new file mode 100644 index 0000000..afda582 --- /dev/null +++ b/expectations_test.go @@ -0,0 +1,103 @@ +package sqlmock + +import ( + "database/sql/driver" + "errors" + "fmt" + "reflect" + "testing" +) + +type CustomConverter struct{} + +func (s CustomConverter) ConvertValue(v interface{}) (driver.Value, error) { + switch v.(type) { + case string: + return v.(string), nil + case []string: + return v.([]string), nil + case int: + return v.(int), nil + default: + return nil, errors.New(fmt.Sprintf("cannot convert %T with value %v", v, v)) + } +} + +func ExampleExpectedExec() { + db, mock, _ := New() + result := NewErrorResult(fmt.Errorf("some error")) + mock.ExpectExec("^INSERT (.+)").WillReturnResult(result) + res, _ := db.Exec("INSERT something") + _, err := res.LastInsertId() + fmt.Println(err) + // Output: some error +} + +func TestBuildQuery(t *testing.T) { + db, mock, _ := New() + query := ` + SELECT + name, + email, + address, + anotherfield + FROM user + where + name = 'John' + and + address = 'Jakarta' + + ` + + mock.ExpectQuery(query) + mock.ExpectExec(query) + mock.ExpectPrepare(query) + + db.QueryRow(query) + db.Exec(query) + db.Prepare(query) + + if err := mock.ExpectationsWereMet(); err != nil { + t.Error(err) + } +} + +func TestCustomValueConverterQueryScan(t *testing.T) { + db, mock, _ := New(ValueConverterOption(CustomConverter{})) + query := ` + SELECT + name, + email, + address, + anotherfield + FROM user + where + name = 'John' + and + address = 'Jakarta' + + ` + expectedStringValue := "ValueOne" + expectedIntValue := 2 + expectedArrayValue := []string{"Three", "Four"} + mock.ExpectQuery(query).WillReturnRows(mock.NewRows([]string{"One", "Two", "Three"}).AddRow(expectedStringValue, expectedIntValue, []string{"Three", "Four"})) + row := db.QueryRow(query) + var stringValue string + var intValue int + var arrayValue []string + if e := row.Scan(&stringValue, &intValue, &arrayValue); e != nil { + t.Error(e) + } + if stringValue != expectedStringValue { + t.Errorf("Expectation %s does not met: %s", expectedStringValue, stringValue) + } + if intValue != expectedIntValue { + t.Errorf("Expectation %d does not met: %d", expectedIntValue, intValue) + } + if !reflect.DeepEqual(expectedArrayValue, arrayValue) { + t.Errorf("Expectation %v does not met: %v", expectedArrayValue, arrayValue) + } + if err := mock.ExpectationsWereMet(); err != nil { + t.Error(err) + } +} diff --git a/go.mod b/go.mod new file mode 100644 index 0000000..d81db53 --- /dev/null +++ b/go.mod @@ -0,0 +1,3 @@ +module github.com/pashagolub/pgxmock + +go 1.15 diff --git a/options.go b/options.go new file mode 100644 index 0000000..00c9837 --- /dev/null +++ b/options.go @@ -0,0 +1,38 @@ +package sqlmock + +import "database/sql/driver" + +// ValueConverterOption allows to create a sqlmock connection +// with a custom ValueConverter to support drivers with special data types. +func ValueConverterOption(converter driver.ValueConverter) func(*sqlmock) error { + return func(s *sqlmock) error { + s.converter = converter + return nil + } +} + +// QueryMatcherOption allows to customize SQL query matcher +// and match SQL query strings in more sophisticated ways. +// The default QueryMatcher is QueryMatcherRegexp. +func QueryMatcherOption(queryMatcher QueryMatcher) func(*sqlmock) error { + return func(s *sqlmock) error { + s.queryMatcher = queryMatcher + return nil + } +} + +// MonitorPingsOption determines whether calls to Ping on the driver should be +// observed and mocked. +// +// If true is passed, we will check these calls were expected. Expectations can +// be registered using the ExpectPing() method on the mock. +// +// If false is passed or this option is omitted, calls to Ping will not be +// considered when determining expectations and calls to ExpectPing will have +// no effect. +func MonitorPingsOption(monitorPings bool) func(*sqlmock) error { + return func(s *sqlmock) error { + s.monitorPings = monitorPings + return nil + } +} diff --git a/query.go b/query.go new file mode 100644 index 0000000..47d3796 --- /dev/null +++ b/query.go @@ -0,0 +1,68 @@ +package sqlmock + +import ( + "fmt" + "regexp" + "strings" +) + +var re = regexp.MustCompile("\\s+") + +// strip out new lines and trim spaces +func stripQuery(q string) (s string) { + return strings.TrimSpace(re.ReplaceAllString(q, " ")) +} + +// QueryMatcher is an SQL query string matcher interface, +// which can be used to customize validation of SQL query strings. +// As an example, external library could be used to build +// and validate SQL ast, columns selected. +// +// sqlmock can be customized to implement a different QueryMatcher +// configured through an option when sqlmock.New or sqlmock.NewWithDSN +// is called, default QueryMatcher is QueryMatcherRegexp. +type QueryMatcher interface { + + // Match expected SQL query string without whitespace to + // actual SQL. + Match(expectedSQL, actualSQL string) error +} + +// QueryMatcherFunc type is an adapter to allow the use of +// ordinary functions as QueryMatcher. If f is a function +// with the appropriate signature, QueryMatcherFunc(f) is a +// QueryMatcher that calls f. +type QueryMatcherFunc func(expectedSQL, actualSQL string) error + +// Match implements the QueryMatcher +func (f QueryMatcherFunc) Match(expectedSQL, actualSQL string) error { + return f(expectedSQL, actualSQL) +} + +// QueryMatcherRegexp is the default SQL query matcher +// used by sqlmock. It parses expectedSQL to a regular +// expression and attempts to match actualSQL. +var QueryMatcherRegexp QueryMatcher = QueryMatcherFunc(func(expectedSQL, actualSQL string) error { + expect := stripQuery(expectedSQL) + actual := stripQuery(actualSQL) + re, err := regexp.Compile(expect) + if err != nil { + return err + } + if !re.MatchString(actual) { + return fmt.Errorf(`could not match actual sql: "%s" with expected regexp "%s"`, actual, re.String()) + } + return nil +}) + +// QueryMatcherEqual is the SQL query matcher +// which simply tries a case sensitive match of +// expected and actual SQL strings without whitespace. +var QueryMatcherEqual QueryMatcher = QueryMatcherFunc(func(expectedSQL, actualSQL string) error { + expect := stripQuery(expectedSQL) + actual := stripQuery(actualSQL) + if actual != expect { + return fmt.Errorf(`actual sql: "%s" does not equal to expected "%s"`, actual, expect) + } + return nil +}) diff --git a/query_test.go b/query_test.go new file mode 100644 index 0000000..0ba7bdc --- /dev/null +++ b/query_test.go @@ -0,0 +1,123 @@ +package sqlmock + +import ( + "fmt" + "testing" +) + +func ExampleQueryMatcher() { + // configure to use case sensitive SQL query matcher + // instead of default regular expression matcher + db, mock, err := New(QueryMatcherOption(QueryMatcherEqual)) + if err != nil { + fmt.Println("failed to open sqlmock database:", err) + } + defer db.Close() + + rows := NewRows([]string{"id", "title"}). + AddRow(1, "one"). + AddRow(2, "two") + + mock.ExpectQuery("SELECT * FROM users").WillReturnRows(rows) + + rs, err := db.Query("SELECT * FROM users") + if err != nil { + fmt.Println("failed to match expected query") + return + } + defer rs.Close() + + for rs.Next() { + var id int + var title string + rs.Scan(&id, &title) + fmt.Println("scanned id:", id, "and title:", title) + } + + if rs.Err() != nil { + fmt.Println("got rows error:", rs.Err()) + } + // Output: scanned id: 1 and title: one + // scanned id: 2 and title: two +} + +func TestQueryStringStripping(t *testing.T) { + assert := func(actual, expected string) { + if res := stripQuery(actual); res != expected { + t.Errorf("Expected '%s' to be '%s', but got '%s'", actual, expected, res) + } + } + + assert(" SELECT 1", "SELECT 1") + assert("SELECT 1 FROM d", "SELECT 1 FROM d") + assert(` + SELECT c + FROM D +`, "SELECT c FROM D") + assert("UPDATE (.+) SET ", "UPDATE (.+) SET") +} + +func TestQueryMatcherRegexp(t *testing.T) { + type testCase struct { + expected string + actual string + err error + } + + cases := []testCase{ + {"?\\l", "SEL", fmt.Errorf("error parsing regexp: missing argument to repetition operator: `?`")}, + {"SELECT (.+) FROM users", "SELECT name, email FROM users WHERE id = ?", nil}, + {"Select (.+) FROM users", "SELECT name, email FROM users WHERE id = ?", fmt.Errorf(`could not match actual sql: "SELECT name, email FROM users WHERE id = ?" with expected regexp "Select (.+) FROM users"`)}, + {"SELECT (.+) FROM\nusers", "SELECT name, email\n FROM users\n WHERE id = ?", nil}, + } + + for i, c := range cases { + err := QueryMatcherRegexp.Match(c.expected, c.actual) + if err == nil && c.err != nil { + t.Errorf(`got no error, but expected "%v" at %d case`, c.err, i) + continue + } + if err != nil && c.err == nil { + t.Errorf(`got unexpected error "%v" at %d case`, err, i) + continue + } + if err == nil { + continue + } + if err.Error() != c.err.Error() { + t.Errorf(`expected error "%v", but got "%v" at %d case`, c.err, err, i) + } + } +} + +func TestQueryMatcherEqual(t *testing.T) { + type testCase struct { + expected string + actual string + err error + } + + cases := []testCase{ + {"SELECT name, email FROM users WHERE id = ?", "SELECT name, email\n FROM users\n WHERE id = ?", nil}, + {"SELECT", "Select", fmt.Errorf(`actual sql: "Select" does not equal to expected "SELECT"`)}, + {"SELECT from users", "SELECT from table", fmt.Errorf(`actual sql: "SELECT from table" does not equal to expected "SELECT from users"`)}, + } + + for i, c := range cases { + err := QueryMatcherEqual.Match(c.expected, c.actual) + if err == nil && c.err != nil { + t.Errorf(`got no error, but expected "%v" at %d case`, c.err, i) + continue + } + if err != nil && c.err == nil { + t.Errorf(`got unexpected error "%v" at %d case`, err, i) + continue + } + if err == nil { + continue + } + if err.Error() != c.err.Error() { + t.Errorf(`expected error "%v", but got "%v" at %d case`, c.err, err, i) + } + } +} diff --git a/result.go b/result.go new file mode 100644 index 0000000..a63e72b --- /dev/null +++ b/result.go @@ -0,0 +1,39 @@ +package sqlmock + +import ( + "database/sql/driver" +) + +// Result satisfies sql driver Result, which +// holds last insert id and rows affected +// by Exec queries +type result struct { + insertID int64 + rowsAffected int64 + err error +} + +// NewResult creates a new sql driver Result +// for Exec based query mocks. +func NewResult(lastInsertID int64, rowsAffected int64) driver.Result { + return &result{ + insertID: lastInsertID, + rowsAffected: rowsAffected, + } +} + +// NewErrorResult creates a new sql driver Result +// which returns an error given for both interface methods +func NewErrorResult(err error) driver.Result { + return &result{ + err: err, + } +} + +func (r *result) LastInsertId() (int64, error) { + return r.insertID, r.err +} + +func (r *result) RowsAffected() (int64, error) { + return r.rowsAffected, r.err +} diff --git a/result_test.go b/result_test.go new file mode 100644 index 0000000..f4eb815 --- /dev/null +++ b/result_test.go @@ -0,0 +1,62 @@ +package sqlmock + +import ( + "fmt" + "testing" +) + +// used for examples +var mock = &sqlmock{} + +func ExampleNewErrorResult() { + db, mock, _ := New() + result := NewErrorResult(fmt.Errorf("some error")) + mock.ExpectExec("^INSERT (.+)").WillReturnResult(result) + res, _ := db.Exec("INSERT something") + _, err := res.LastInsertId() + fmt.Println(err) + // Output: some error +} + +func ExampleNewResult() { + var lastInsertID, affected int64 + result := NewResult(lastInsertID, affected) + mock.ExpectExec("^INSERT (.+)").WillReturnResult(result) + fmt.Println(mock.ExpectationsWereMet()) + // Output: there is a remaining expectation which was not matched: ExpectedExec => expecting Exec or ExecContext which: + // - matches sql: '^INSERT (.+)' + // - is without arguments + // - should return Result having: + // LastInsertId: 0 + // RowsAffected: 0 +} + +func TestShouldReturnValidSqlDriverResult(t *testing.T) { + result := NewResult(1, 2) + id, err := result.LastInsertId() + if 1 != id { + t.Errorf("expected last insert id to be 1, but got: %d", id) + } + if err != nil { + t.Errorf("expected no error, but got: %s", err) + } + affected, err := result.RowsAffected() + if 2 != affected { + t.Errorf("expected affected rows to be 2, but got: %d", affected) + } + if err != nil { + t.Errorf("expected no error, but got: %s", err) + } +} + +func TestShouldReturnErrorSqlDriverResult(t *testing.T) { + result := NewErrorResult(fmt.Errorf("some error")) + _, err := result.LastInsertId() + if err == nil { + t.Error("expected error, but got none") + } + _, err = result.RowsAffected() + if err == nil { + t.Error("expected error, but got none") + } +} diff --git a/rows.go b/rows.go new file mode 100644 index 0000000..ccc5f0c --- /dev/null +++ b/rows.go @@ -0,0 +1,212 @@ +package sqlmock + +import ( + "bytes" + "database/sql/driver" + "encoding/csv" + "fmt" + "io" + "strings" +) + +const invalidate = "☠☠☠ MEMORY OVERWRITTEN ☠☠☠ " + +// CSVColumnParser is a function which converts trimmed csv +// column string to a []byte representation. Currently +// transforms NULL to nil +var CSVColumnParser = func(s string) []byte { + switch { + case strings.ToLower(s) == "null": + return nil + } + return []byte(s) +} + +type rowSets struct { + sets []*Rows + pos int + ex *ExpectedQuery + raw [][]byte +} + +func (rs *rowSets) Columns() []string { + return rs.sets[rs.pos].cols +} + +func (rs *rowSets) Close() error { + rs.invalidateRaw() + rs.ex.rowsWereClosed = true + return rs.sets[rs.pos].closeErr +} + +// advances to next row +func (rs *rowSets) Next(dest []driver.Value) error { + r := rs.sets[rs.pos] + r.pos++ + rs.invalidateRaw() + if r.pos > len(r.rows) { + return io.EOF // per interface spec + } + + for i, col := range r.rows[r.pos-1] { + if b, ok := rawBytes(col); ok { + rs.raw = append(rs.raw, b) + dest[i] = b + continue + } + dest[i] = col + } + + return r.nextErr[r.pos-1] +} + +// transforms to debuggable printable string +func (rs *rowSets) String() string { + if rs.empty() { + return "with empty rows" + } + + msg := "should return rows:\n" + if len(rs.sets) == 1 { + for n, row := range rs.sets[0].rows { + msg += fmt.Sprintf(" row %d - %+v\n", n, row) + } + return strings.TrimSpace(msg) + } + for i, set := range rs.sets { + msg += fmt.Sprintf(" result set: %d\n", i) + for n, row := range set.rows { + msg += fmt.Sprintf(" row %d - %+v\n", n, row) + } + } + return strings.TrimSpace(msg) +} + +func (rs *rowSets) empty() bool { + for _, set := range rs.sets { + if len(set.rows) > 0 { + return false + } + } + return true +} + +func rawBytes(col driver.Value) (_ []byte, ok bool) { + val, ok := col.([]byte) + if !ok || len(val) == 0 { + return nil, false + } + // Copy the bytes from the mocked row into a shared raw buffer, which we'll replace the content of later + // This allows scanning into sql.RawBytes to correctly become invalid on subsequent calls to Next(), Scan() or Close() + b := make([]byte, len(val)) + copy(b, val) + return b, true +} + +// Bytes that could have been scanned as sql.RawBytes are only valid until the next call to Next, Scan or Close. +// If those occur, we must replace their content to simulate the shared memory to expose misuse of sql.RawBytes +func (rs *rowSets) invalidateRaw() { + // Replace the content of slices previously returned + b := []byte(invalidate) + for _, r := range rs.raw { + copy(r, bytes.Repeat(b, len(r)/len(b)+1)) + } + // Start with new slices for the next scan + rs.raw = nil +} + +// Rows is a mocked collection of rows to +// return for Query result +type Rows struct { + converter driver.ValueConverter + cols []string + def []*Column + rows [][]driver.Value + pos int + nextErr map[int]error + closeErr error +} + +// NewRows allows Rows to be created from a +// sql driver.Value slice or from the CSV string and +// to be used as sql driver.Rows. +// Use Sqlmock.NewRows instead if using a custom converter +func NewRows(columns []string) *Rows { + return &Rows{ + cols: columns, + nextErr: make(map[int]error), + converter: driver.DefaultParameterConverter, + } +} + +// CloseError allows to set an error +// which will be returned by rows.Close +// function. +// +// The close error will be triggered only in cases +// when rows.Next() EOF was not yet reached, that is +// a default sql library behavior +func (r *Rows) CloseError(err error) *Rows { + r.closeErr = err + return r +} + +// RowError allows to set an error +// which will be returned when a given +// row number is read +func (r *Rows) RowError(row int, err error) *Rows { + r.nextErr[row] = err + return r +} + +// AddRow composed from database driver.Value slice +// return the same instance to perform subsequent actions. +// Note that the number of values must match the number +// of columns +func (r *Rows) AddRow(values ...driver.Value) *Rows { + if len(values) != len(r.cols) { + panic("Expected number of values to match number of columns") + } + + row := make([]driver.Value, len(r.cols)) + for i, v := range values { + // Convert user-friendly values (such as int or driver.Valuer) + // to database/sql native value (driver.Value such as int64) + var err error + v, err = r.converter.ConvertValue(v) + if err != nil { + panic(fmt.Errorf( + "row #%d, column #%d (%q) type %T: %s", + len(r.rows)+1, i, r.cols[i], values[i], err, + )) + } + + row[i] = v + } + + r.rows = append(r.rows, row) + return r +} + +// FromCSVString build rows from csv string. +// return the same instance to perform subsequent actions. +// Note that the number of values must match the number +// of columns +func (r *Rows) FromCSVString(s string) *Rows { + res := strings.NewReader(strings.TrimSpace(s)) + csvReader := csv.NewReader(res) + + for { + res, err := csvReader.Read() + if err != nil || res == nil { + break + } + + row := make([]driver.Value, len(r.cols)) + for i, v := range res { + row[i] = CSVColumnParser(strings.TrimSpace(v)) + } + r.rows = append(r.rows, row) + } + return r +} diff --git a/rows_go13_test.go b/rows_go13_test.go new file mode 100644 index 0000000..5c9038c --- /dev/null +++ b/rows_go13_test.go @@ -0,0 +1,31 @@ +// +build go1.3 + +package sqlmock + +import ( + "database/sql" + "testing" +) + +func TestQueryRowBytesNotInvalidatedByNext_stringIntoRawBytes(t *testing.T) { + t.Parallel() + rows := NewRows([]string{"raw"}). + AddRow(`one binary value with some text!`). + AddRow(`two binary value with even more text than the first one`) + scan := func(rs *sql.Rows) ([]byte, error) { + var raw sql.RawBytes + return raw, rs.Scan(&raw) + } + want := [][]byte{[]byte(`one binary value with some text!`), []byte(`two binary value with even more text than the first one`)} + queryRowBytesNotInvalidatedByNext(t, rows, scan, want) +} + +func TestQueryRowBytesNotInvalidatedByClose_stringIntoRawBytes(t *testing.T) { + t.Parallel() + rows := NewRows([]string{"raw"}).AddRow(`one binary value with some text!`) + scan := func(rs *sql.Rows) ([]byte, error) { + var raw sql.RawBytes + return raw, rs.Scan(&raw) + } + queryRowBytesNotInvalidatedByClose(t, rows, scan, []byte(`one binary value with some text!`)) +} diff --git a/rows_go18.go b/rows_go18.go new file mode 100644 index 0000000..6c71eb9 --- /dev/null +++ b/rows_go18.go @@ -0,0 +1,74 @@ +// +build go1.8 + +package sqlmock + +import ( + "database/sql/driver" + "io" + "reflect" +) + +// Implement the "RowsNextResultSet" interface +func (rs *rowSets) HasNextResultSet() bool { + return rs.pos+1 < len(rs.sets) +} + +// Implement the "RowsNextResultSet" interface +func (rs *rowSets) NextResultSet() error { + if !rs.HasNextResultSet() { + return io.EOF + } + + rs.pos++ + return nil +} + +// type for rows with columns definition created with sqlmock.NewRowsWithColumnDefinition +type rowSetsWithDefinition struct { + *rowSets +} + +// Implement the "RowsColumnTypeDatabaseTypeName" interface +func (rs *rowSetsWithDefinition) ColumnTypeDatabaseTypeName(index int) string { + return rs.getDefinition(index).DbType() +} + +// Implement the "RowsColumnTypeLength" interface +func (rs *rowSetsWithDefinition) ColumnTypeLength(index int) (length int64, ok bool) { + return rs.getDefinition(index).Length() +} + +// Implement the "RowsColumnTypeNullable" interface +func (rs *rowSetsWithDefinition) ColumnTypeNullable(index int) (nullable, ok bool) { + return rs.getDefinition(index).IsNullable() +} + +// Implement the "RowsColumnTypePrecisionScale" interface +func (rs *rowSetsWithDefinition) ColumnTypePrecisionScale(index int) (precision, scale int64, ok bool) { + return rs.getDefinition(index).PrecisionScale() +} + +// ColumnTypeScanType is defined from driver.RowsColumnTypeScanType +func (rs *rowSetsWithDefinition) ColumnTypeScanType(index int) reflect.Type { + return rs.getDefinition(index).ScanType() +} + +// return column definition from current set metadata +func (rs *rowSetsWithDefinition) getDefinition(index int) *Column { + return rs.sets[rs.pos].def[index] +} + +// NewRowsWithColumnDefinition return rows with columns metadata +func NewRowsWithColumnDefinition(columns ...*Column) *Rows { + cols := make([]string, len(columns)) + for i, column := range columns { + cols[i] = column.Name() + } + + return &Rows{ + cols: cols, + def: columns, + nextErr: make(map[int]error), + converter: driver.DefaultParameterConverter, + } +} diff --git a/rows_go18_test.go b/rows_go18_test.go new file mode 100644 index 0000000..0af6d66 --- /dev/null +++ b/rows_go18_test.go @@ -0,0 +1,387 @@ +// +build go1.8 + +package sqlmock + +import ( + "database/sql" + "encoding/json" + "fmt" + "reflect" + "testing" + "time" +) + +func TestQueryMultiRows(t *testing.T) { + t.Parallel() + db, mock, err := New() + if err != nil { + t.Errorf("an error '%s' was not expected when opening a stub database connection", err) + } + defer db.Close() + + rs1 := NewRows([]string{"id", "title"}).AddRow(5, "hello world") + rs2 := NewRows([]string{"name"}).AddRow("gopher").AddRow("john").AddRow("jane").RowError(2, fmt.Errorf("error")) + + mock.ExpectQuery("SELECT (.+) FROM articles WHERE id = \\?;SELECT name FROM users"). + WithArgs(5). + WillReturnRows(rs1, rs2) + + rows, err := db.Query("SELECT id, title FROM articles WHERE id = ?;SELECT name FROM users", 5) + if err != nil { + t.Errorf("error was not expected, but got: %v", err) + } + defer rows.Close() + + if !rows.Next() { + t.Error("expected a row to be available in first result set") + } + + var id int + var name string + + err = rows.Scan(&id, &name) + if err != nil { + t.Errorf("error was not expected, but got: %v", err) + } + + if id != 5 || name != "hello world" { + t.Errorf("unexpected row values id: %v name: %v", id, name) + } + + if rows.Next() { + t.Error("was not expecting next row in first result set") + } + + if !rows.NextResultSet() { + t.Error("had to have next result set") + } + + if !rows.Next() { + t.Error("expected a row to be available in second result set") + } + + err = rows.Scan(&name) + if err != nil { + t.Errorf("error was not expected, but got: %v", err) + } + + if name != "gopher" { + t.Errorf("unexpected row name: %v", name) + } + + if !rows.Next() { + t.Error("expected a row to be available in second result set") + } + + err = rows.Scan(&name) + if err != nil { + t.Errorf("error was not expected, but got: %v", err) + } + + if name != "john" { + t.Errorf("unexpected row name: %v", name) + } + + if rows.Next() { + t.Error("expected next row to produce error") + } + + if rows.Err() == nil { + t.Error("expected an error, but there was none") + } + + if err := mock.ExpectationsWereMet(); err != nil { + t.Errorf("there were unfulfilled expectations: %s", err) + } +} + +func TestQueryRowBytesInvalidatedByNext_jsonRawMessageIntoRawBytes(t *testing.T) { + t.Parallel() + replace := []byte(invalid) + rows := NewRows([]string{"raw"}). + AddRow(json.RawMessage(`{"thing": "one", "thing2": "two"}`)). + AddRow(json.RawMessage(`{"that": "foo", "this": "bar"}`)) + scan := func(rs *sql.Rows) ([]byte, error) { + var raw sql.RawBytes + return raw, rs.Scan(&raw) + } + want := []struct { + Initial []byte + Replaced []byte + }{ + {Initial: []byte(`{"thing": "one", "thing2": "two"}`), Replaced: replace[:len(replace)-6]}, + {Initial: []byte(`{"that": "foo", "this": "bar"}`), Replaced: replace[:len(replace)-9]}, + } + queryRowBytesInvalidatedByNext(t, rows, scan, want) +} + +func TestQueryRowBytesNotInvalidatedByNext_jsonRawMessageIntoBytes(t *testing.T) { + t.Parallel() + rows := NewRows([]string{"raw"}). + AddRow(json.RawMessage(`{"thing": "one", "thing2": "two"}`)). + AddRow(json.RawMessage(`{"that": "foo", "this": "bar"}`)) + scan := func(rs *sql.Rows) ([]byte, error) { + var b []byte + return b, rs.Scan(&b) + } + want := [][]byte{[]byte(`{"thing": "one", "thing2": "two"}`), []byte(`{"that": "foo", "this": "bar"}`)} + queryRowBytesNotInvalidatedByNext(t, rows, scan, want) +} + +func TestQueryRowBytesNotInvalidatedByNext_bytesIntoCustomBytes(t *testing.T) { + t.Parallel() + rows := NewRows([]string{"raw"}). + AddRow([]byte(`one binary value with some text!`)). + AddRow([]byte(`two binary value with even more text than the first one`)) + scan := func(rs *sql.Rows) ([]byte, error) { + type customBytes []byte + var b customBytes + return b, rs.Scan(&b) + } + want := [][]byte{[]byte(`one binary value with some text!`), []byte(`two binary value with even more text than the first one`)} + queryRowBytesNotInvalidatedByNext(t, rows, scan, want) +} + +func TestQueryRowBytesNotInvalidatedByNext_jsonRawMessageIntoCustomBytes(t *testing.T) { + t.Parallel() + rows := NewRows([]string{"raw"}). + AddRow(json.RawMessage(`{"thing": "one", "thing2": "two"}`)). + AddRow(json.RawMessage(`{"that": "foo", "this": "bar"}`)) + scan := func(rs *sql.Rows) ([]byte, error) { + type customBytes []byte + var b customBytes + return b, rs.Scan(&b) + } + want := [][]byte{[]byte(`{"thing": "one", "thing2": "two"}`), []byte(`{"that": "foo", "this": "bar"}`)} + queryRowBytesNotInvalidatedByNext(t, rows, scan, want) +} + +func TestQueryRowBytesNotInvalidatedByClose_bytesIntoCustomBytes(t *testing.T) { + t.Parallel() + rows := NewRows([]string{"raw"}).AddRow([]byte(`one binary value with some text!`)) + scan := func(rs *sql.Rows) ([]byte, error) { + type customBytes []byte + var b customBytes + return b, rs.Scan(&b) + } + queryRowBytesNotInvalidatedByClose(t, rows, scan, []byte(`one binary value with some text!`)) +} + +func TestQueryRowBytesInvalidatedByClose_jsonRawMessageIntoRawBytes(t *testing.T) { + t.Parallel() + replace := []byte(invalid) + rows := NewRows([]string{"raw"}).AddRow(json.RawMessage(`{"thing": "one", "thing2": "two"}`)) + scan := func(rs *sql.Rows) ([]byte, error) { + var raw sql.RawBytes + return raw, rs.Scan(&raw) + } + want := struct { + Initial []byte + Replaced []byte + }{ + Initial: []byte(`{"thing": "one", "thing2": "two"}`), + Replaced: replace[:len(replace)-6], + } + queryRowBytesInvalidatedByClose(t, rows, scan, want) +} + +func TestQueryRowBytesNotInvalidatedByClose_jsonRawMessageIntoBytes(t *testing.T) { + t.Parallel() + rows := NewRows([]string{"raw"}).AddRow(json.RawMessage(`{"thing": "one", "thing2": "two"}`)) + scan := func(rs *sql.Rows) ([]byte, error) { + var b []byte + return b, rs.Scan(&b) + } + queryRowBytesNotInvalidatedByClose(t, rows, scan, []byte(`{"thing": "one", "thing2": "two"}`)) +} + +func TestQueryRowBytesNotInvalidatedByClose_jsonRawMessageIntoCustomBytes(t *testing.T) { + t.Parallel() + rows := NewRows([]string{"raw"}).AddRow(json.RawMessage(`{"thing": "one", "thing2": "two"}`)) + scan := func(rs *sql.Rows) ([]byte, error) { + type customBytes []byte + var b customBytes + return b, rs.Scan(&b) + } + queryRowBytesNotInvalidatedByClose(t, rows, scan, []byte(`{"thing": "one", "thing2": "two"}`)) +} + +func TestNewColumnWithDefinition(t *testing.T) { + now, _ := time.Parse(time.RFC3339, "2020-06-20T22:08:41Z") + + t.Run("with one ResultSet", func(t *testing.T) { + db, mock, _ := New() + column1 := mock.NewColumn("test").OfType("VARCHAR", "").Nullable(true).WithLength(100) + column2 := mock.NewColumn("number").OfType("DECIMAL", float64(0.0)).Nullable(false).WithPrecisionAndScale(10, 4) + column3 := mock.NewColumn("when").OfType("TIMESTAMP", now) + rows := mock.NewRowsWithColumnDefinition(column1, column2, column3) + rows.AddRow("foo.bar", float64(10.123), now) + + mQuery := mock.ExpectQuery("SELECT test, number, when from dummy") + isQuery := mQuery.WillReturnRows(rows) + isQueryClosed := mQuery.RowsWillBeClosed() + isDbClosed := mock.ExpectClose() + + query, _ := db.Query("SELECT test, number, when from dummy") + + if false == isQuery.fulfilled() { + t.Error("Query is not executed") + } + + if query.Next() { + var test string + var number float64 + var when time.Time + + if queryError := query.Scan(&test, &number, &when); queryError != nil { + t.Error(queryError) + } else if test != "foo.bar" { + t.Error("field test is not 'foo.bar'") + } else if number != float64(10.123) { + t.Error("field number is not '10.123'") + } else if when != now { + t.Errorf("field when is not %v", now) + } + + if columnTypes, colTypErr := query.ColumnTypes(); colTypErr != nil { + t.Error(colTypErr) + } else if len(columnTypes) != 3 { + t.Error("number of columnTypes") + } else if name := columnTypes[0].Name(); name != "test" { + t.Errorf("field 'test' has a wrong name '%s'", name) + } else if dbType := columnTypes[0].DatabaseTypeName(); dbType != "VARCHAR" { + t.Errorf("field 'test' has a wrong db type '%s'", dbType) + } else if columnTypes[0].ScanType().Kind() != reflect.String { + t.Error("field 'test' has a wrong scanType") + } else if _, _, ok := columnTypes[0].DecimalSize(); ok { + t.Error("field 'test' should have not precision, scale") + } else if length, ok := columnTypes[0].Length(); length != 100 || !ok { + t.Errorf("field 'test' has a wrong length '%d'", length) + } else if name := columnTypes[1].Name(); name != "number" { + t.Errorf("field 'number' has a wrong name '%s'", name) + } else if dbType := columnTypes[1].DatabaseTypeName(); dbType != "DECIMAL" { + t.Errorf("field 'number' has a wrong db type '%s'", dbType) + } else if columnTypes[1].ScanType().Kind() != reflect.Float64 { + t.Error("field 'number' has a wrong scanType") + } else if precision, scale, ok := columnTypes[1].DecimalSize(); precision != int64(10) || scale != int64(4) || !ok { + t.Error("field 'number' has a wrong precision, scale") + } else if _, ok := columnTypes[1].Length(); ok { + t.Error("field 'number' is not variable length type") + } else if _, ok := columnTypes[2].Nullable(); ok { + t.Error("field 'when' should have nullability unknown") + } + } else { + t.Error("no result set") + } + + query.Close() + if false == isQueryClosed.fulfilled() { + t.Error("Query is not executed") + } + + db.Close() + if false == isDbClosed.fulfilled() { + t.Error("Db is not closed") + } + }) + + t.Run("with more then one ResultSet", func(t *testing.T) { + db, mock, _ := New() + column1 := mock.NewColumn("test").OfType("VARCHAR", "").Nullable(true).WithLength(100) + column2 := mock.NewColumn("number").OfType("DECIMAL", float64(0.0)).Nullable(false).WithPrecisionAndScale(10, 4) + column3 := mock.NewColumn("when").OfType("TIMESTAMP", now) + rows1 := mock.NewRowsWithColumnDefinition(column1, column2, column3) + rows1.AddRow("foo.bar", float64(10.123), now) + rows2 := mock.NewRowsWithColumnDefinition(column1, column2, column3) + rows2.AddRow("bar.foo", float64(123.10), now.Add(time.Second*10)) + rows3 := mock.NewRowsWithColumnDefinition(column1, column2, column3) + rows3.AddRow("lollipop", float64(10.321), now.Add(time.Second*20)) + + mQuery := mock.ExpectQuery("SELECT test, number, when from dummy") + isQuery := mQuery.WillReturnRows(rows1, rows2, rows3) + isQueryClosed := mQuery.RowsWillBeClosed() + isDbClosed := mock.ExpectClose() + + query, _ := db.Query("SELECT test, number, when from dummy") + + if false == isQuery.fulfilled() { + t.Error("Query is not executed") + } + + rowsSi := 0 + + for query.Next() { + var test string + var number float64 + var when time.Time + + if queryError := query.Scan(&test, &number, &when); queryError != nil { + t.Error(queryError) + + } else if rowsSi == 0 && test != "foo.bar" { + t.Error("field test is not 'foo.bar'") + } else if rowsSi == 0 && number != float64(10.123) { + t.Error("field number is not '10.123'") + } else if rowsSi == 0 && when != now { + t.Errorf("field when is not %v", now) + + } else if rowsSi == 1 && test != "bar.foo" { + t.Error("field test is not 'bar.bar'") + } else if rowsSi == 1 && number != float64(123.10) { + t.Error("field number is not '123.10'") + } else if rowsSi == 1 && when != now.Add(time.Second*10) { + t.Errorf("field when is not %v", now) + + } else if rowsSi == 2 && test != "lollipop" { + t.Error("field test is not 'lollipop'") + } else if rowsSi == 2 && number != float64(10.321) { + t.Error("field number is not '10.321'") + } else if rowsSi == 2 && when != now.Add(time.Second*20) { + t.Errorf("field when is not %v", now) + } + + rowsSi++ + + if columnTypes, colTypErr := query.ColumnTypes(); colTypErr != nil { + t.Error(colTypErr) + } else if len(columnTypes) != 3 { + t.Error("number of columnTypes") + } else if name := columnTypes[0].Name(); name != "test" { + t.Errorf("field 'test' has a wrong name '%s'", name) + } else if dbType := columnTypes[0].DatabaseTypeName(); dbType != "VARCHAR" { + t.Errorf("field 'test' has a wrong db type '%s'", dbType) + } else if columnTypes[0].ScanType().Kind() != reflect.String { + t.Error("field 'test' has a wrong scanType") + } else if _, _, ok := columnTypes[0].DecimalSize(); ok { + t.Error("field 'test' should not have precision, scale") + } else if length, ok := columnTypes[0].Length(); length != 100 || !ok { + t.Errorf("field 'test' has a wrong length '%d'", length) + } else if name := columnTypes[1].Name(); name != "number" { + t.Errorf("field 'number' has a wrong name '%s'", name) + } else if dbType := columnTypes[1].DatabaseTypeName(); dbType != "DECIMAL" { + t.Errorf("field 'number' has a wrong db type '%s'", dbType) + } else if columnTypes[1].ScanType().Kind() != reflect.Float64 { + t.Error("field 'number' has a wrong scanType") + } else if precision, scale, ok := columnTypes[1].DecimalSize(); precision != int64(10) || scale != int64(4) || !ok { + t.Error("field 'number' has a wrong precision, scale") + } else if _, ok := columnTypes[1].Length(); ok { + t.Error("field 'number' is not variable length type") + } else if _, ok := columnTypes[2].Nullable(); ok { + t.Error("field 'when' should have nullability unknown") + } + } + if rowsSi == 0 { + t.Error("no result set") + } + + query.Close() + if false == isQueryClosed.fulfilled() { + t.Error("Query is not executed") + } + + db.Close() + if false == isDbClosed.fulfilled() { + t.Error("Db is not closed") + } + }) +} diff --git a/rows_test.go b/rows_test.go new file mode 100644 index 0000000..15cdbee --- /dev/null +++ b/rows_test.go @@ -0,0 +1,672 @@ +package sqlmock + +import ( + "bytes" + "database/sql" + "fmt" + "testing" +) + +const invalid = `☠☠☠ MEMORY OVERWRITTEN ☠☠☠ ` + +func ExampleRows() { + db, mock, err := New() + if err != nil { + fmt.Println("failed to open sqlmock database:", err) + } + defer db.Close() + + rows := NewRows([]string{"id", "title"}). + AddRow(1, "one"). + AddRow(2, "two") + + mock.ExpectQuery("SELECT").WillReturnRows(rows) + + rs, _ := db.Query("SELECT") + defer rs.Close() + + for rs.Next() { + var id int + var title string + rs.Scan(&id, &title) + fmt.Println("scanned id:", id, "and title:", title) + } + + if rs.Err() != nil { + fmt.Println("got rows error:", rs.Err()) + } + // Output: scanned id: 1 and title: one + // scanned id: 2 and title: two +} + +func ExampleRows_rowError() { + db, mock, err := New() + if err != nil { + fmt.Println("failed to open sqlmock database:", err) + } + defer db.Close() + + rows := NewRows([]string{"id", "title"}). + AddRow(0, "one"). + AddRow(1, "two"). + RowError(1, fmt.Errorf("row error")) + mock.ExpectQuery("SELECT").WillReturnRows(rows) + + rs, _ := db.Query("SELECT") + defer rs.Close() + + for rs.Next() { + var id int + var title string + rs.Scan(&id, &title) + fmt.Println("scanned id:", id, "and title:", title) + } + + if rs.Err() != nil { + fmt.Println("got rows error:", rs.Err()) + } + // Output: scanned id: 0 and title: one + // got rows error: row error +} + +func ExampleRows_closeError() { + db, mock, err := New() + if err != nil { + fmt.Println("failed to open sqlmock database:", err) + } + defer db.Close() + + rows := NewRows([]string{"id", "title"}).CloseError(fmt.Errorf("close error")) + mock.ExpectQuery("SELECT").WillReturnRows(rows) + + rs, _ := db.Query("SELECT") + + // Note: that close will return error only before rows EOF + // that is a default sql package behavior. If you run rs.Next() + // it will handle the error internally and return nil bellow + if err := rs.Close(); err != nil { + fmt.Println("got error:", err) + } + + // Output: got error: close error +} + +func ExampleRows_rawBytes() { + db, mock, err := New() + if err != nil { + fmt.Println("failed to open sqlmock database:", err) + } + defer db.Close() + + rows := NewRows([]string{"id", "binary"}). + AddRow(1, []byte(`one binary value with some text!`)). + AddRow(2, []byte(`two binary value with even more text than the first one`)) + + mock.ExpectQuery("SELECT").WillReturnRows(rows) + + rs, _ := db.Query("SELECT") + defer rs.Close() + + type scanned struct { + id int + raw sql.RawBytes + } + fmt.Println("initial read...") + var ss []scanned + for rs.Next() { + var s scanned + rs.Scan(&s.id, &s.raw) + ss = append(ss, s) + fmt.Println("scanned id:", s.id, "and raw:", string(s.raw)) + } + + if rs.Err() != nil { + fmt.Println("got rows error:", rs.Err()) + } + + fmt.Println("after reading all...") + for _, s := range ss { + fmt.Println("scanned id:", s.id, "and raw:", string(s.raw)) + } + // Output: + // initial read... + // scanned id: 1 and raw: one binary value with some text! + // scanned id: 2 and raw: two binary value with even more text than the first one + // after reading all... + // scanned id: 1 and raw: ☠☠☠ MEMORY OVERWRITTEN ☠ + // scanned id: 2 and raw: ☠☠☠ MEMORY OVERWRITTEN ☠☠☠ ☠☠☠ MEMORY +} + +func ExampleRows_expectToBeClosed() { + db, mock, err := New() + if err != nil { + fmt.Println("failed to open sqlmock database:", err) + } + defer db.Close() + + rows := NewRows([]string{"id", "title"}).AddRow(1, "john") + mock.ExpectQuery("SELECT").WillReturnRows(rows).RowsWillBeClosed() + + db.Query("SELECT") + + if err := mock.ExpectationsWereMet(); err != nil { + fmt.Println("got error:", err) + } + + // Output: got error: expected query rows to be closed, but it was not: ExpectedQuery => expecting Query, QueryContext or QueryRow which: + // - matches sql: 'SELECT' + // - is without arguments + // - should return rows: + // row 0 - [1 john] +} + +func ExampleRows_customDriverValue() { + db, mock, err := New() + if err != nil { + fmt.Println("failed to open sqlmock database:", err) + } + defer db.Close() + + rows := NewRows([]string{"id", "null_int"}). + AddRow(1, 7). + AddRow(5, sql.NullInt64{Int64: 5, Valid: true}). + AddRow(2, sql.NullInt64{}) + + mock.ExpectQuery("SELECT").WillReturnRows(rows) + + rs, _ := db.Query("SELECT") + defer rs.Close() + + for rs.Next() { + var id int + var num sql.NullInt64 + rs.Scan(&id, &num) + fmt.Println("scanned id:", id, "and null int64:", num) + } + + if rs.Err() != nil { + fmt.Println("got rows error:", rs.Err()) + } + // Output: scanned id: 1 and null int64: {7 true} + // scanned id: 5 and null int64: {5 true} + // scanned id: 2 and null int64: {0 false} +} + +func TestAllowsToSetRowsErrors(t *testing.T) { + t.Parallel() + db, mock, err := New() + if err != nil { + t.Fatalf("an error '%s' was not expected when opening a stub database connection", err) + } + defer db.Close() + + rows := NewRows([]string{"id", "title"}). + AddRow(0, "one"). + AddRow(1, "two"). + RowError(1, fmt.Errorf("error")) + mock.ExpectQuery("SELECT").WillReturnRows(rows) + + rs, err := db.Query("SELECT") + if err != nil { + t.Fatalf("unexpected error: %s", err) + } + defer rs.Close() + + if !rs.Next() { + t.Fatal("expected the first row to be available") + } + if rs.Err() != nil { + t.Fatalf("unexpected error: %s", rs.Err()) + } + + if rs.Next() { + t.Fatal("was not expecting the second row, since there should be an error") + } + if rs.Err() == nil { + t.Fatal("expected an error, but got none") + } + + if err := mock.ExpectationsWereMet(); err != nil { + t.Fatal(err) + } +} + +func TestRowsCloseError(t *testing.T) { + t.Parallel() + db, mock, err := New() + if err != nil { + t.Fatalf("an error '%s' was not expected when opening a stub database connection", err) + } + defer db.Close() + + rows := NewRows([]string{"id"}).CloseError(fmt.Errorf("close error")) + mock.ExpectQuery("SELECT").WillReturnRows(rows) + + rs, err := db.Query("SELECT") + if err != nil { + t.Fatalf("unexpected error: %s", err) + } + + if err := rs.Close(); err == nil { + t.Fatal("expected a close error") + } + + if err := mock.ExpectationsWereMet(); err != nil { + t.Fatal(err) + } +} + +func TestRowsClosed(t *testing.T) { + t.Parallel() + db, mock, err := New() + if err != nil { + t.Fatalf("an error '%s' was not expected when opening a stub database connection", err) + } + defer db.Close() + + rows := NewRows([]string{"id"}).AddRow(1) + mock.ExpectQuery("SELECT").WillReturnRows(rows).RowsWillBeClosed() + + rs, err := db.Query("SELECT") + if err != nil { + t.Fatalf("unexpected error: %s", err) + } + + if err := rs.Close(); err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if err := mock.ExpectationsWereMet(); err != nil { + t.Fatal(err) + } +} + +func TestQuerySingleRow(t *testing.T) { + t.Parallel() + db, mock, err := New() + if err != nil { + t.Fatalf("an error '%s' was not expected when opening a stub database connection", err) + } + defer db.Close() + + rows := NewRows([]string{"id"}). + AddRow(1). + AddRow(2) + mock.ExpectQuery("SELECT").WillReturnRows(rows) + + var id int + if err := db.QueryRow("SELECT").Scan(&id); err != nil { + t.Fatalf("unexpected error: %s", err) + } + + mock.ExpectQuery("SELECT").WillReturnRows(NewRows([]string{"id"})) + if err := db.QueryRow("SELECT").Scan(&id); err != sql.ErrNoRows { + t.Fatal("expected sql no rows error") + } + + if err := mock.ExpectationsWereMet(); err != nil { + t.Fatal(err) + } +} + +func TestQueryRowBytesInvalidatedByNext_bytesIntoRawBytes(t *testing.T) { + t.Parallel() + replace := []byte(invalid) + rows := NewRows([]string{"raw"}). + AddRow([]byte(`one binary value with some text!`)). + AddRow([]byte(`two binary value with even more text than the first one`)) + scan := func(rs *sql.Rows) ([]byte, error) { + var raw sql.RawBytes + return raw, rs.Scan(&raw) + } + want := []struct { + Initial []byte + Replaced []byte + }{ + {Initial: []byte(`one binary value with some text!`), Replaced: replace[:len(replace)-7]}, + {Initial: []byte(`two binary value with even more text than the first one`), Replaced: bytes.Join([][]byte{replace, replace[:len(replace)-23]}, nil)}, + } + queryRowBytesInvalidatedByNext(t, rows, scan, want) +} + +func TestQueryRowBytesNotInvalidatedByNext_bytesIntoBytes(t *testing.T) { + t.Parallel() + rows := NewRows([]string{"raw"}). + AddRow([]byte(`one binary value with some text!`)). + AddRow([]byte(`two binary value with even more text than the first one`)) + scan := func(rs *sql.Rows) ([]byte, error) { + var b []byte + return b, rs.Scan(&b) + } + want := [][]byte{[]byte(`one binary value with some text!`), []byte(`two binary value with even more text than the first one`)} + queryRowBytesNotInvalidatedByNext(t, rows, scan, want) +} + +func TestQueryRowBytesNotInvalidatedByNext_stringIntoBytes(t *testing.T) { + t.Parallel() + rows := NewRows([]string{"raw"}). + AddRow(`one binary value with some text!`). + AddRow(`two binary value with even more text than the first one`) + scan := func(rs *sql.Rows) ([]byte, error) { + var b []byte + return b, rs.Scan(&b) + } + want := [][]byte{[]byte(`one binary value with some text!`), []byte(`two binary value with even more text than the first one`)} + queryRowBytesNotInvalidatedByNext(t, rows, scan, want) +} + +func TestQueryRowBytesInvalidatedByClose_bytesIntoRawBytes(t *testing.T) { + t.Parallel() + replace := []byte(invalid) + rows := NewRows([]string{"raw"}).AddRow([]byte(`one binary value with some text!`)) + scan := func(rs *sql.Rows) ([]byte, error) { + var raw sql.RawBytes + return raw, rs.Scan(&raw) + } + want := struct { + Initial []byte + Replaced []byte + }{ + Initial: []byte(`one binary value with some text!`), + Replaced: replace[:len(replace)-7], + } + queryRowBytesInvalidatedByClose(t, rows, scan, want) +} + +func TestQueryRowBytesNotInvalidatedByClose_bytesIntoBytes(t *testing.T) { + t.Parallel() + rows := NewRows([]string{"raw"}).AddRow([]byte(`one binary value with some text!`)) + scan := func(rs *sql.Rows) ([]byte, error) { + var b []byte + return b, rs.Scan(&b) + } + queryRowBytesNotInvalidatedByClose(t, rows, scan, []byte(`one binary value with some text!`)) +} + +func TestQueryRowBytesNotInvalidatedByClose_stringIntoBytes(t *testing.T) { + t.Parallel() + rows := NewRows([]string{"raw"}).AddRow(`one binary value with some text!`) + scan := func(rs *sql.Rows) ([]byte, error) { + var b []byte + return b, rs.Scan(&b) + } + queryRowBytesNotInvalidatedByClose(t, rows, scan, []byte(`one binary value with some text!`)) +} + +func TestRowsScanError(t *testing.T) { + t.Parallel() + db, mock, err := New() + if err != nil { + t.Fatalf("an error '%s' was not expected when opening a stub database connection", err) + } + defer db.Close() + + r := NewRows([]string{"col1", "col2"}).AddRow("one", "two").AddRow("one", nil) + mock.ExpectQuery("SELECT").WillReturnRows(r) + + rs, err := db.Query("SELECT") + if err != nil { + t.Fatalf("unexpected error: %s", err) + } + defer rs.Close() + + var one, two string + if !rs.Next() || rs.Err() != nil || rs.Scan(&one, &two) != nil { + t.Fatal("unexpected error on first row scan") + } + + if !rs.Next() || rs.Err() != nil { + t.Fatal("unexpected error on second row read") + } + + err = rs.Scan(&one, &two) + if err == nil { + t.Fatal("expected an error for scan, but got none") + } + + if err := mock.ExpectationsWereMet(); err != nil { + t.Fatal(err) + } +} + +func TestCSVRowParser(t *testing.T) { + t.Parallel() + rs := NewRows([]string{"col1", "col2"}).FromCSVString("a,NULL") + db, mock, err := New() + if err != nil { + t.Fatalf("an error '%s' was not expected when opening a stub database connection", err) + } + defer db.Close() + + mock.ExpectQuery("SELECT").WillReturnRows(rs) + + rw, err := db.Query("SELECT") + if err != nil { + t.Fatalf("unexpected error: %s", err) + } + defer rw.Close() + var col1 string + var col2 []byte + + rw.Next() + if err = rw.Scan(&col1, &col2); err != nil { + t.Fatalf("unexpected error: %s", err) + } + if col1 != "a" { + t.Fatalf("expected col1 to be 'a', but got [%T]:%+v", col1, col1) + } + if col2 != nil { + t.Fatalf("expected col2 to be nil, but got [%T]:%+v", col2, col2) + } +} + +func TestWrongNumberOfValues(t *testing.T) { + // Open new mock database + db, mock, err := New() + if err != nil { + fmt.Println("error creating mock database") + return + } + defer db.Close() + defer func() { + recover() + }() + mock.ExpectQuery("SELECT ID FROM TABLE").WithArgs(101).WillReturnRows(NewRows([]string{"ID"}).AddRow(101, "Hello")) + db.Query("SELECT ID FROM TABLE", 101) + // shouldn't reach here + t.Error("expected panic from query") +} + +func TestEmptyRowSets(t *testing.T) { + rs1 := NewRows([]string{"a"}).AddRow("a") + rs2 := NewRows([]string{"b"}) + rs3 := NewRows([]string{"c"}) + + set1 := &rowSets{sets: []*Rows{rs1, rs2}} + set2 := &rowSets{sets: []*Rows{rs3, rs2}} + set3 := &rowSets{sets: []*Rows{rs2}} + + if set1.empty() { + t.Fatalf("expected rowset 1, not to be empty, but it was") + } + if !set2.empty() { + t.Fatalf("expected rowset 2, to be empty, but it was not") + } + if !set3.empty() { + t.Fatalf("expected rowset 3, to be empty, but it was not") + } +} + +func queryRowBytesInvalidatedByNext(t *testing.T, rows *Rows, scan func(*sql.Rows) ([]byte, error), want []struct { + Initial []byte + Replaced []byte +}) { + db, mock, err := New() + if err != nil { + t.Fatalf("an error '%s' was not expected when opening a stub database connection", err) + } + defer db.Close() + mock.ExpectQuery("SELECT").WillReturnRows(rows) + + rs, err := db.Query("SELECT") + if err != nil { + t.Fatalf("failed to query rows: %s", err) + } + + if !rs.Next() || rs.Err() != nil { + t.Fatal("unexpected error on first row retrieval") + } + var count int + for i := 0; ; i++ { + count++ + b, err := scan(rs) + if err != nil { + t.Fatalf("unexpected error scanning row: %s", err) + } + if exp := want[i].Initial; !bytes.Equal(b, exp) { + t.Fatalf("expected raw value to be '%s' (len:%d), but got [%T]:%s (len:%d)", exp, len(exp), b, b, len(b)) + } + next := rs.Next() + if exp := want[i].Replaced; !bytes.Equal(b, exp) { + t.Fatalf("expected raw value to be replaced with '%s' (len:%d) after calling Next(), but got [%T]:%s (len:%d)", exp, len(exp), b, b, len(b)) + } + if !next { + break + } + } + if err := rs.Err(); err != nil { + t.Fatalf("row iteration failed: %s", err) + } + if exp := len(want); count != exp { + t.Fatalf("incorrect number of rows exp: %d, but got %d", exp, count) + } + + if err := mock.ExpectationsWereMet(); err != nil { + t.Fatal(err) + } +} + +func queryRowBytesNotInvalidatedByNext(t *testing.T, rows *Rows, scan func(*sql.Rows) ([]byte, error), want [][]byte) { + db, mock, err := New() + if err != nil { + t.Fatalf("an error '%s' was not expected when opening a stub database connection", err) + } + defer db.Close() + mock.ExpectQuery("SELECT").WillReturnRows(rows) + + rs, err := db.Query("SELECT") + if err != nil { + t.Fatalf("failed to query rows: %s", err) + } + + if !rs.Next() || rs.Err() != nil { + t.Fatal("unexpected error on first row retrieval") + } + var count int + for i := 0; ; i++ { + count++ + b, err := scan(rs) + if err != nil { + t.Fatalf("unexpected error scanning row: %s", err) + } + if exp := want[i]; !bytes.Equal(b, exp) { + t.Fatalf("expected raw value to be '%s' (len:%d), but got [%T]:%s (len:%d)", exp, len(exp), b, b, len(b)) + } + next := rs.Next() + if exp := want[i]; !bytes.Equal(b, exp) { + t.Fatalf("expected raw value to be replaced with '%s' (len:%d) after calling Next(), but got [%T]:%s (len:%d)", exp, len(exp), b, b, len(b)) + } + if !next { + break + } + } + if err := rs.Err(); err != nil { + t.Fatalf("row iteration failed: %s", err) + } + if exp := len(want); count != exp { + t.Fatalf("incorrect number of rows exp: %d, but got %d", exp, count) + } + + if err := mock.ExpectationsWereMet(); err != nil { + t.Fatal(err) + } +} + +func queryRowBytesInvalidatedByClose(t *testing.T, rows *Rows, scan func(*sql.Rows) ([]byte, error), want struct { + Initial []byte + Replaced []byte +}) { + db, mock, err := New() + if err != nil { + t.Fatalf("an error '%s' was not expected when opening a stub database connection", err) + } + defer db.Close() + mock.ExpectQuery("SELECT").WillReturnRows(rows) + + rs, err := db.Query("SELECT") + if err != nil { + t.Fatalf("failed to query rows: %s", err) + } + + if !rs.Next() || rs.Err() != nil { + t.Fatal("unexpected error on first row retrieval") + } + b, err := scan(rs) + if err != nil { + t.Fatalf("unexpected error scanning row: %s", err) + } + if !bytes.Equal(b, want.Initial) { + t.Fatalf("expected raw value to be '%s' (len:%d), but got [%T]:%s (len:%d)", want.Initial, len(want.Initial), b, b, len(b)) + } + if err := rs.Close(); err != nil { + t.Fatalf("unexpected error closing rows: %s", err) + } + if !bytes.Equal(b, want.Replaced) { + t.Fatalf("expected raw value to be replaced with '%s' (len:%d) after calling Next(), but got [%T]:%s (len:%d)", want.Replaced, len(want.Replaced), b, b, len(b)) + } + if err := rs.Err(); err != nil { + t.Fatalf("row iteration failed: %s", err) + } + + if err := mock.ExpectationsWereMet(); err != nil { + t.Fatal(err) + } +} + +func queryRowBytesNotInvalidatedByClose(t *testing.T, rows *Rows, scan func(*sql.Rows) ([]byte, error), want []byte) { + db, mock, err := New() + if err != nil { + t.Fatalf("an error '%s' was not expected when opening a stub database connection", err) + } + defer db.Close() + mock.ExpectQuery("SELECT").WillReturnRows(rows) + + rs, err := db.Query("SELECT") + if err != nil { + t.Fatalf("failed to query rows: %s", err) + } + + if !rs.Next() || rs.Err() != nil { + t.Fatal("unexpected error on first row retrieval") + } + b, err := scan(rs) + if err != nil { + t.Fatalf("unexpected error scanning row: %s", err) + } + if !bytes.Equal(b, want) { + t.Fatalf("expected raw value to be '%s' (len:%d), but got [%T]:%s (len:%d)", want, len(want), b, b, len(b)) + } + if err := rs.Close(); err != nil { + t.Fatalf("unexpected error closing rows: %s", err) + } + if !bytes.Equal(b, want) { + t.Fatalf("expected raw value to be replaced with '%s' (len:%d) after calling Next(), but got [%T]:%s (len:%d)", want, len(want), b, b, len(b)) + } + if err := rs.Err(); err != nil { + t.Fatalf("row iteration failed: %s", err) + } + + if err := mock.ExpectationsWereMet(); err != nil { + t.Fatal(err) + } +} diff --git a/sqlmock.go b/sqlmock.go new file mode 100644 index 0000000..d074266 --- /dev/null +++ b/sqlmock.go @@ -0,0 +1,439 @@ +/* +Package sqlmock is a mock library implementing sql driver. Which has one and only +purpose - to simulate any sql driver behavior in tests, without needing a real +database connection. It helps to maintain correct **TDD** workflow. + +It does not require any modifications to your source code in order to test +and mock database operations. Supports concurrency and multiple database mocking. + +The driver allows to mock any sql driver method behavior. +*/ +package sqlmock + +import ( + "database/sql" + "database/sql/driver" + "fmt" + "time" +) + +// Sqlmock interface serves to create expectations +// for any kind of database action in order to mock +// and test real database behavior. +type SqlmockCommon interface { + // ExpectClose queues an expectation for this database + // action to be triggered. the *ExpectedClose allows + // to mock database response + ExpectClose() *ExpectedClose + + // ExpectationsWereMet checks whether all queued expectations + // were met in order. If any of them was not met - an error is returned. + ExpectationsWereMet() error + + // ExpectPrepare expects Prepare() to be called with expectedSQL query. + // the *ExpectedPrepare allows to mock database response. + // Note that you may expect Query() or Exec() on the *ExpectedPrepare + // statement to prevent repeating expectedSQL + ExpectPrepare(expectedSQL string) *ExpectedPrepare + + // ExpectQuery expects Query() or QueryRow() to be called with expectedSQL query. + // the *ExpectedQuery allows to mock database response. + ExpectQuery(expectedSQL string) *ExpectedQuery + + // ExpectExec expects Exec() to be called with expectedSQL query. + // the *ExpectedExec allows to mock database response + ExpectExec(expectedSQL string) *ExpectedExec + + // ExpectBegin expects *sql.DB.Begin to be called. + // the *ExpectedBegin allows to mock database response + ExpectBegin() *ExpectedBegin + + // ExpectCommit expects *sql.Tx.Commit to be called. + // the *ExpectedCommit allows to mock database response + ExpectCommit() *ExpectedCommit + + // ExpectRollback expects *sql.Tx.Rollback to be called. + // the *ExpectedRollback allows to mock database response + ExpectRollback() *ExpectedRollback + + // ExpectPing expected *sql.DB.Ping to be called. + // the *ExpectedPing allows to mock database response + // + // Ping support only exists in the SQL library in Go 1.8 and above. + // ExpectPing in Go <=1.7 will return an ExpectedPing but not register + // any expectations. + // + // You must enable pings using MonitorPingsOption for this to register + // any expectations. + ExpectPing() *ExpectedPing + + // MatchExpectationsInOrder gives an option whether to match all + // expectations in the order they were set or not. + // + // By default it is set to - true. But if you use goroutines + // to parallelize your query executation, that option may + // be handy. + // + // This option may be turned on anytime during tests. As soon + // as it is switched to false, expectations will be matched + // in any order. Or otherwise if switched to true, any unmatched + // expectations will be expected in order + MatchExpectationsInOrder(bool) + + // NewRows allows Rows to be created from a + // sql driver.Value slice or from the CSV string and + // to be used as sql driver.Rows. + NewRows(columns []string) *Rows +} + +type sqlmock struct { + ordered bool + dsn string + opened int + drv *mockDriver + converter driver.ValueConverter + queryMatcher QueryMatcher + monitorPings bool + + expected []expectation +} + +func (c *sqlmock) open(options []func(*sqlmock) error) (*sql.DB, Sqlmock, error) { + db, err := sql.Open("sqlmock", c.dsn) + if err != nil { + return db, c, err + } + for _, option := range options { + err := option(c) + if err != nil { + return db, c, err + } + } + if c.converter == nil { + c.converter = driver.DefaultParameterConverter + } + if c.queryMatcher == nil { + c.queryMatcher = QueryMatcherRegexp + } + + if c.monitorPings { + // We call Ping on the driver shortly to verify startup assertions by + // driving internal behaviour of the sql standard library. We don't + // want this call to ping to be monitored for expectation purposes so + // temporarily disable. + c.monitorPings = false + defer func() { c.monitorPings = true }() + } + return db, c, db.Ping() +} + +func (c *sqlmock) ExpectClose() *ExpectedClose { + e := &ExpectedClose{} + c.expected = append(c.expected, e) + return e +} + +func (c *sqlmock) MatchExpectationsInOrder(b bool) { + c.ordered = b +} + +// Close a mock database driver connection. It may or may not +// be called depending on the circumstances, but if it is called +// there must be an *ExpectedClose expectation satisfied. +// meets http://golang.org/pkg/database/sql/driver/#Conn interface +func (c *sqlmock) Close() error { + c.drv.Lock() + defer c.drv.Unlock() + + c.opened-- + if c.opened == 0 { + delete(c.drv.conns, c.dsn) + } + + var expected *ExpectedClose + var fulfilled int + var ok bool + for _, next := range c.expected { + next.Lock() + if next.fulfilled() { + next.Unlock() + fulfilled++ + continue + } + + if expected, ok = next.(*ExpectedClose); ok { + break + } + + next.Unlock() + if c.ordered { + return fmt.Errorf("call to database Close, was not expected, next expectation is: %s", next) + } + } + + if expected == nil { + msg := "call to database Close was not expected" + if fulfilled == len(c.expected) { + msg = "all expectations were already fulfilled, " + msg + } + return fmt.Errorf(msg) + } + + expected.triggered = true + expected.Unlock() + return expected.err +} + +func (c *sqlmock) ExpectationsWereMet() error { + for _, e := range c.expected { + e.Lock() + fulfilled := e.fulfilled() + e.Unlock() + + if !fulfilled { + return fmt.Errorf("there is a remaining expectation which was not matched: %s", e) + } + + // for expected prepared statement check whether it was closed if expected + if prep, ok := e.(*ExpectedPrepare); ok { + if prep.mustBeClosed && !prep.wasClosed { + return fmt.Errorf("expected prepared statement to be closed, but it was not: %s", prep) + } + } + + // must check whether all expected queried rows are closed + if query, ok := e.(*ExpectedQuery); ok { + if query.rowsMustBeClosed && !query.rowsWereClosed { + return fmt.Errorf("expected query rows to be closed, but it was not: %s", query) + } + } + } + return nil +} + +// Begin meets http://golang.org/pkg/database/sql/driver/#Conn interface +func (c *sqlmock) Begin() (driver.Tx, error) { + ex, err := c.begin() + if ex != nil { + time.Sleep(ex.delay) + } + if err != nil { + return nil, err + } + + return c, nil +} + +func (c *sqlmock) begin() (*ExpectedBegin, error) { + var expected *ExpectedBegin + var ok bool + var fulfilled int + for _, next := range c.expected { + next.Lock() + if next.fulfilled() { + next.Unlock() + fulfilled++ + continue + } + + if expected, ok = next.(*ExpectedBegin); ok { + break + } + + next.Unlock() + if c.ordered { + return nil, fmt.Errorf("call to database transaction Begin, was not expected, next expectation is: %s", next) + } + } + if expected == nil { + msg := "call to database transaction Begin was not expected" + if fulfilled == len(c.expected) { + msg = "all expectations were already fulfilled, " + msg + } + return nil, fmt.Errorf(msg) + } + + expected.triggered = true + expected.Unlock() + + return expected, expected.err +} + +func (c *sqlmock) ExpectBegin() *ExpectedBegin { + e := &ExpectedBegin{} + c.expected = append(c.expected, e) + return e +} + +func (c *sqlmock) ExpectExec(expectedSQL string) *ExpectedExec { + e := &ExpectedExec{} + e.expectSQL = expectedSQL + e.converter = c.converter + c.expected = append(c.expected, e) + return e +} + +// Prepare meets http://golang.org/pkg/database/sql/driver/#Conn interface +func (c *sqlmock) Prepare(query string) (driver.Stmt, error) { + ex, err := c.prepare(query) + if ex != nil { + time.Sleep(ex.delay) + } + if err != nil { + return nil, err + } + + return &statement{c, ex, query}, nil +} + +func (c *sqlmock) prepare(query string) (*ExpectedPrepare, error) { + var expected *ExpectedPrepare + var fulfilled int + var ok bool + + for _, next := range c.expected { + next.Lock() + if next.fulfilled() { + next.Unlock() + fulfilled++ + continue + } + + if c.ordered { + if expected, ok = next.(*ExpectedPrepare); ok { + break + } + + next.Unlock() + return nil, fmt.Errorf("call to Prepare statement with query '%s', was not expected, next expectation is: %s", query, next) + } + + if pr, ok := next.(*ExpectedPrepare); ok { + if err := c.queryMatcher.Match(pr.expectSQL, query); err == nil { + expected = pr + break + } + } + next.Unlock() + } + + if expected == nil { + msg := "call to Prepare '%s' query was not expected" + if fulfilled == len(c.expected) { + msg = "all expectations were already fulfilled, " + msg + } + return nil, fmt.Errorf(msg, query) + } + defer expected.Unlock() + if err := c.queryMatcher.Match(expected.expectSQL, query); err != nil { + return nil, fmt.Errorf("Prepare: %v", err) + } + + expected.triggered = true + return expected, expected.err +} + +func (c *sqlmock) ExpectPrepare(expectedSQL string) *ExpectedPrepare { + e := &ExpectedPrepare{expectSQL: expectedSQL, mock: c} + c.expected = append(c.expected, e) + return e +} + +func (c *sqlmock) ExpectQuery(expectedSQL string) *ExpectedQuery { + e := &ExpectedQuery{} + e.expectSQL = expectedSQL + e.converter = c.converter + c.expected = append(c.expected, e) + return e +} + +func (c *sqlmock) ExpectCommit() *ExpectedCommit { + e := &ExpectedCommit{} + c.expected = append(c.expected, e) + return e +} + +func (c *sqlmock) ExpectRollback() *ExpectedRollback { + e := &ExpectedRollback{} + c.expected = append(c.expected, e) + return e +} + +// Commit meets http://golang.org/pkg/database/sql/driver/#Tx +func (c *sqlmock) Commit() error { + var expected *ExpectedCommit + var fulfilled int + var ok bool + for _, next := range c.expected { + next.Lock() + if next.fulfilled() { + next.Unlock() + fulfilled++ + continue + } + + if expected, ok = next.(*ExpectedCommit); ok { + break + } + + next.Unlock() + if c.ordered { + return fmt.Errorf("call to Commit transaction, was not expected, next expectation is: %s", next) + } + } + if expected == nil { + msg := "call to Commit transaction was not expected" + if fulfilled == len(c.expected) { + msg = "all expectations were already fulfilled, " + msg + } + return fmt.Errorf(msg) + } + + expected.triggered = true + expected.Unlock() + return expected.err +} + +// Rollback meets http://golang.org/pkg/database/sql/driver/#Tx +func (c *sqlmock) Rollback() error { + var expected *ExpectedRollback + var fulfilled int + var ok bool + for _, next := range c.expected { + next.Lock() + if next.fulfilled() { + next.Unlock() + fulfilled++ + continue + } + + if expected, ok = next.(*ExpectedRollback); ok { + break + } + + next.Unlock() + if c.ordered { + return fmt.Errorf("call to Rollback transaction, was not expected, next expectation is: %s", next) + } + } + if expected == nil { + msg := "call to Rollback transaction was not expected" + if fulfilled == len(c.expected) { + msg = "all expectations were already fulfilled, " + msg + } + return fmt.Errorf(msg) + } + + expected.triggered = true + expected.Unlock() + return expected.err +} + +// NewRows allows Rows to be created from a +// sql driver.Value slice or from the CSV string and +// to be used as sql driver.Rows. +func (c *sqlmock) NewRows(columns []string) *Rows { + r := NewRows(columns) + r.converter = c.converter + return r +} diff --git a/sqlmock_before_go18.go b/sqlmock_before_go18.go new file mode 100644 index 0000000..9965e78 --- /dev/null +++ b/sqlmock_before_go18.go @@ -0,0 +1,191 @@ +// +build !go1.8 + +package sqlmock + +import ( + "database/sql/driver" + "fmt" + "log" + "time" +) + +// Sqlmock interface for Go up to 1.7 +type Sqlmock interface { + // Embed common methods + SqlmockCommon +} + +type namedValue struct { + Name string + Ordinal int + Value driver.Value +} + +func (c *sqlmock) ExpectPing() *ExpectedPing { + log.Println("ExpectPing has no effect on Go 1.7 or below") + return &ExpectedPing{} +} + +// Query meets http://golang.org/pkg/database/sql/driver/#Queryer +func (c *sqlmock) Query(query string, args []driver.Value) (driver.Rows, error) { + namedArgs := make([]namedValue, len(args)) + for i, v := range args { + namedArgs[i] = namedValue{ + Ordinal: i + 1, + Value: v, + } + } + + ex, err := c.query(query, namedArgs) + if ex != nil { + time.Sleep(ex.delay) + } + if err != nil { + return nil, err + } + + return ex.rows, nil +} + +func (c *sqlmock) query(query string, args []namedValue) (*ExpectedQuery, error) { + var expected *ExpectedQuery + var fulfilled int + var ok bool + for _, next := range c.expected { + next.Lock() + if next.fulfilled() { + next.Unlock() + fulfilled++ + continue + } + + if c.ordered { + if expected, ok = next.(*ExpectedQuery); ok { + break + } + next.Unlock() + return nil, fmt.Errorf("call to Query '%s' with args %+v, was not expected, next expectation is: %s", query, args, next) + } + if qr, ok := next.(*ExpectedQuery); ok { + if err := c.queryMatcher.Match(qr.expectSQL, query); err != nil { + next.Unlock() + continue + } + if err := qr.attemptArgMatch(args); err == nil { + expected = qr + break + } + } + next.Unlock() + } + + if expected == nil { + msg := "call to Query '%s' with args %+v was not expected" + if fulfilled == len(c.expected) { + msg = "all expectations were already fulfilled, " + msg + } + return nil, fmt.Errorf(msg, query, args) + } + + defer expected.Unlock() + + if err := c.queryMatcher.Match(expected.expectSQL, query); err != nil { + return nil, fmt.Errorf("Query: %v", err) + } + + if err := expected.argsMatches(args); err != nil { + return nil, fmt.Errorf("Query '%s', arguments do not match: %s", query, err) + } + + expected.triggered = true + if expected.err != nil { + return expected, expected.err // mocked to return error + } + + if expected.rows == nil { + return nil, fmt.Errorf("Query '%s' with args %+v, must return a database/sql/driver.Rows, but it was not set for expectation %T as %+v", query, args, expected, expected) + } + return expected, nil +} + +// Exec meets http://golang.org/pkg/database/sql/driver/#Execer +func (c *sqlmock) Exec(query string, args []driver.Value) (driver.Result, error) { + namedArgs := make([]namedValue, len(args)) + for i, v := range args { + namedArgs[i] = namedValue{ + Ordinal: i + 1, + Value: v, + } + } + + ex, err := c.exec(query, namedArgs) + if ex != nil { + time.Sleep(ex.delay) + } + if err != nil { + return nil, err + } + + return ex.result, nil +} + +func (c *sqlmock) exec(query string, args []namedValue) (*ExpectedExec, error) { + var expected *ExpectedExec + var fulfilled int + var ok bool + for _, next := range c.expected { + next.Lock() + if next.fulfilled() { + next.Unlock() + fulfilled++ + continue + } + + if c.ordered { + if expected, ok = next.(*ExpectedExec); ok { + break + } + next.Unlock() + return nil, fmt.Errorf("call to ExecQuery '%s' with args %+v, was not expected, next expectation is: %s", query, args, next) + } + if exec, ok := next.(*ExpectedExec); ok { + if err := c.queryMatcher.Match(exec.expectSQL, query); err != nil { + next.Unlock() + continue + } + + if err := exec.attemptArgMatch(args); err == nil { + expected = exec + break + } + } + next.Unlock() + } + if expected == nil { + msg := "call to ExecQuery '%s' with args %+v was not expected" + if fulfilled == len(c.expected) { + msg = "all expectations were already fulfilled, " + msg + } + return nil, fmt.Errorf(msg, query, args) + } + defer expected.Unlock() + + if err := c.queryMatcher.Match(expected.expectSQL, query); err != nil { + return nil, fmt.Errorf("ExecQuery: %v", err) + } + + if err := expected.argsMatches(args); err != nil { + return nil, fmt.Errorf("ExecQuery '%s', arguments do not match: %s", query, err) + } + + expected.triggered = true + if expected.err != nil { + return expected, expected.err // mocked to return error + } + + if expected.result == nil { + return nil, fmt.Errorf("ExecQuery '%s' with args %+v, must return a database/sql/driver.Result, but it was not set for expectation %T as %+v", query, args, expected, expected) + } + + return expected, nil +} diff --git a/sqlmock_before_go18_test.go b/sqlmock_before_go18_test.go new file mode 100644 index 0000000..3c510a3 --- /dev/null +++ b/sqlmock_before_go18_test.go @@ -0,0 +1,26 @@ +// +build !go1.8 + +package sqlmock + +import ( + "fmt" + "testing" + "time" +) + +func TestSqlmockExpectPingHasNoEffect(t *testing.T) { + db, mock, err := New() + if err != nil { + t.Errorf("an error '%s' was not expected when opening a stub database connection", err) + } + defer db.Close() + + e := mock.ExpectPing() + + // Methods on the expectation can be called + e.WillDelayFor(time.Hour).WillReturnError(fmt.Errorf("an error")) + + if err = mock.ExpectationsWereMet(); err != nil { + t.Errorf("expected no error to be returned, but got '%s'", err) + } +} diff --git a/sqlmock_go18.go b/sqlmock_go18.go new file mode 100644 index 0000000..f268900 --- /dev/null +++ b/sqlmock_go18.go @@ -0,0 +1,356 @@ +// +build go1.8 + +package sqlmock + +import ( + "context" + "database/sql/driver" + "errors" + "fmt" + "log" + "time" +) + +// Sqlmock interface for Go 1.8+ +type Sqlmock interface { + // Embed common methods + SqlmockCommon + + // NewRowsWithColumnDefinition allows Rows to be created from a + // sql driver.Value slice with a definition of sql metadata + NewRowsWithColumnDefinition(columns ...*Column) *Rows + + // New Column allows to create a Column + NewColumn(name string) *Column +} + +// ErrCancelled defines an error value, which can be expected in case of +// such cancellation error. +var ErrCancelled = errors.New("canceling query due to user request") + +// Implement the "QueryerContext" interface +func (c *sqlmock) QueryContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Rows, error) { + ex, err := c.query(query, args) + if ex != nil { + select { + case <-time.After(ex.delay): + if err != nil { + return nil, err + } + return ex.rows, nil + case <-ctx.Done(): + return nil, ErrCancelled + } + } + + return nil, err +} + +// Implement the "ExecerContext" interface +func (c *sqlmock) ExecContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Result, error) { + ex, err := c.exec(query, args) + if ex != nil { + select { + case <-time.After(ex.delay): + if err != nil { + return nil, err + } + return ex.result, nil + case <-ctx.Done(): + return nil, ErrCancelled + } + } + + return nil, err +} + +// Implement the "ConnBeginTx" interface +func (c *sqlmock) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, error) { + ex, err := c.begin() + if ex != nil { + select { + case <-time.After(ex.delay): + if err != nil { + return nil, err + } + return c, nil + case <-ctx.Done(): + return nil, ErrCancelled + } + } + + return nil, err +} + +// Implement the "ConnPrepareContext" interface +func (c *sqlmock) PrepareContext(ctx context.Context, query string) (driver.Stmt, error) { + ex, err := c.prepare(query) + if ex != nil { + select { + case <-time.After(ex.delay): + if err != nil { + return nil, err + } + return &statement{c, ex, query}, nil + case <-ctx.Done(): + return nil, ErrCancelled + } + } + + return nil, err +} + +// Implement the "Pinger" interface - the explicit DB driver ping was only added to database/sql in Go 1.8 +func (c *sqlmock) Ping(ctx context.Context) error { + if !c.monitorPings { + return nil + } + + ex, err := c.ping() + if ex != nil { + select { + case <-ctx.Done(): + return ErrCancelled + case <-time.After(ex.delay): + } + } + + return err +} + +func (c *sqlmock) ping() (*ExpectedPing, error) { + var expected *ExpectedPing + var fulfilled int + var ok bool + for _, next := range c.expected { + next.Lock() + if next.fulfilled() { + next.Unlock() + fulfilled++ + continue + } + + if expected, ok = next.(*ExpectedPing); ok { + break + } + + next.Unlock() + if c.ordered { + return nil, fmt.Errorf("call to database Ping, was not expected, next expectation is: %s", next) + } + } + + if expected == nil { + msg := "call to database Ping was not expected" + if fulfilled == len(c.expected) { + msg = "all expectations were already fulfilled, " + msg + } + return nil, fmt.Errorf(msg) + } + + expected.triggered = true + expected.Unlock() + return expected, expected.err +} + +// Implement the "StmtExecContext" interface +func (stmt *statement) ExecContext(ctx context.Context, args []driver.NamedValue) (driver.Result, error) { + return stmt.conn.ExecContext(ctx, stmt.query, args) +} + +// Implement the "StmtQueryContext" interface +func (stmt *statement) QueryContext(ctx context.Context, args []driver.NamedValue) (driver.Rows, error) { + return stmt.conn.QueryContext(ctx, stmt.query, args) +} + +func (c *sqlmock) ExpectPing() *ExpectedPing { + if !c.monitorPings { + log.Println("ExpectPing will have no effect as monitoring pings is disabled. Use MonitorPingsOption to enable.") + return nil + } + e := &ExpectedPing{} + c.expected = append(c.expected, e) + return e +} + +// Query meets http://golang.org/pkg/database/sql/driver/#Queryer +// Deprecated: Drivers should implement QueryerContext instead. +func (c *sqlmock) Query(query string, args []driver.Value) (driver.Rows, error) { + namedArgs := make([]driver.NamedValue, len(args)) + for i, v := range args { + namedArgs[i] = driver.NamedValue{ + Ordinal: i + 1, + Value: v, + } + } + + ex, err := c.query(query, namedArgs) + if ex != nil { + time.Sleep(ex.delay) + } + if err != nil { + return nil, err + } + + return ex.rows, nil +} + +func (c *sqlmock) query(query string, args []driver.NamedValue) (*ExpectedQuery, error) { + var expected *ExpectedQuery + var fulfilled int + var ok bool + for _, next := range c.expected { + next.Lock() + if next.fulfilled() { + next.Unlock() + fulfilled++ + continue + } + + if c.ordered { + if expected, ok = next.(*ExpectedQuery); ok { + break + } + next.Unlock() + return nil, fmt.Errorf("call to Query '%s' with args %+v, was not expected, next expectation is: %s", query, args, next) + } + if qr, ok := next.(*ExpectedQuery); ok { + if err := c.queryMatcher.Match(qr.expectSQL, query); err != nil { + next.Unlock() + continue + } + if err := qr.attemptArgMatch(args); err == nil { + expected = qr + break + } + } + next.Unlock() + } + + if expected == nil { + msg := "call to Query '%s' with args %+v was not expected" + if fulfilled == len(c.expected) { + msg = "all expectations were already fulfilled, " + msg + } + return nil, fmt.Errorf(msg, query, args) + } + + defer expected.Unlock() + + if err := c.queryMatcher.Match(expected.expectSQL, query); err != nil { + return nil, fmt.Errorf("Query: %v", err) + } + + if err := expected.argsMatches(args); err != nil { + return nil, fmt.Errorf("Query '%s', arguments do not match: %s", query, err) + } + + expected.triggered = true + if expected.err != nil { + return expected, expected.err // mocked to return error + } + + if expected.rows == nil { + return nil, fmt.Errorf("Query '%s' with args %+v, must return a database/sql/driver.Rows, but it was not set for expectation %T as %+v", query, args, expected, expected) + } + return expected, nil +} + +// Exec meets http://golang.org/pkg/database/sql/driver/#Execer +// Deprecated: Drivers should implement ExecerContext instead. +func (c *sqlmock) Exec(query string, args []driver.Value) (driver.Result, error) { + namedArgs := make([]driver.NamedValue, len(args)) + for i, v := range args { + namedArgs[i] = driver.NamedValue{ + Ordinal: i + 1, + Value: v, + } + } + + ex, err := c.exec(query, namedArgs) + if ex != nil { + time.Sleep(ex.delay) + } + if err != nil { + return nil, err + } + + return ex.result, nil +} + +func (c *sqlmock) exec(query string, args []driver.NamedValue) (*ExpectedExec, error) { + var expected *ExpectedExec + var fulfilled int + var ok bool + for _, next := range c.expected { + next.Lock() + if next.fulfilled() { + next.Unlock() + fulfilled++ + continue + } + + if c.ordered { + if expected, ok = next.(*ExpectedExec); ok { + break + } + next.Unlock() + return nil, fmt.Errorf("call to ExecQuery '%s' with args %+v, was not expected, next expectation is: %s", query, args, next) + } + if exec, ok := next.(*ExpectedExec); ok { + if err := c.queryMatcher.Match(exec.expectSQL, query); err != nil { + next.Unlock() + continue + } + + if err := exec.attemptArgMatch(args); err == nil { + expected = exec + break + } + } + next.Unlock() + } + if expected == nil { + msg := "call to ExecQuery '%s' with args %+v was not expected" + if fulfilled == len(c.expected) { + msg = "all expectations were already fulfilled, " + msg + } + return nil, fmt.Errorf(msg, query, args) + } + defer expected.Unlock() + + if err := c.queryMatcher.Match(expected.expectSQL, query); err != nil { + return nil, fmt.Errorf("ExecQuery: %v", err) + } + + if err := expected.argsMatches(args); err != nil { + return nil, fmt.Errorf("ExecQuery '%s', arguments do not match: %s", query, err) + } + + expected.triggered = true + if expected.err != nil { + return expected, expected.err // mocked to return error + } + + if expected.result == nil { + return nil, fmt.Errorf("ExecQuery '%s' with args %+v, must return a database/sql/driver.Result, but it was not set for expectation %T as %+v", query, args, expected, expected) + } + + return expected, nil +} + +// @TODO maybe add ExpectedBegin.WithOptions(driver.TxOptions) + +// NewRowsWithColumnDefinition allows Rows to be created from a +// sql driver.Value slice with a definition of sql metadata +func (c *sqlmock) NewRowsWithColumnDefinition(columns ...*Column) *Rows { + r := NewRowsWithColumnDefinition(columns...) + r.converter = c.converter + return r +} + +// NewColumn allows to create a Column that can be enhanced with metadata +// using OfType/Nullable/WithLength/WithPrecisionAndScale methods. +func (c *sqlmock) NewColumn(name string) *Column { + return NewColumn(name) +} diff --git a/sqlmock_go18_19.go b/sqlmock_go18_19.go new file mode 100644 index 0000000..9d81a7f --- /dev/null +++ b/sqlmock_go18_19.go @@ -0,0 +1,11 @@ +// +build go1.8,!go1.9 + +package sqlmock + +import "database/sql/driver" + +// CheckNamedValue meets https://golang.org/pkg/database/sql/driver/#NamedValueChecker +func (c *sqlmock) CheckNamedValue(nv *driver.NamedValue) (err error) { + nv.Value, err = c.converter.ConvertValue(nv.Value) + return err +} diff --git a/sqlmock_go18_test.go b/sqlmock_go18_test.go new file mode 100644 index 0000000..223e076 --- /dev/null +++ b/sqlmock_go18_test.go @@ -0,0 +1,641 @@ +// +build go1.8 + +package sqlmock + +import ( + "context" + "database/sql" + "errors" + "testing" + "time" +) + +func TestContextExecCancel(t *testing.T) { + t.Parallel() + db, mock, err := New() + if err != nil { + t.Errorf("an error '%s' was not expected when opening a stub database connection", err) + } + defer db.Close() + + mock.ExpectExec("DELETE FROM users"). + WillDelayFor(time.Second). + WillReturnResult(NewResult(1, 1)) + + ctx, cancel := context.WithCancel(context.Background()) + + go func() { + time.Sleep(time.Millisecond * 10) + cancel() + }() + + _, err = db.ExecContext(ctx, "DELETE FROM users") + if err == nil { + t.Error("error was expected, but there was none") + } + + if err != ErrCancelled { + t.Errorf("was expecting cancel error, but got: %v", err) + } + + _, err = db.ExecContext(ctx, "DELETE FROM users") + if err != context.Canceled { + t.Error("error was expected since context was already done, but there was none") + } + + if err := mock.ExpectationsWereMet(); err != nil { + t.Errorf("there were unfulfilled expectations: %s", err) + } +} + +func TestPreparedStatementContextExecCancel(t *testing.T) { + t.Parallel() + db, mock, err := New() + if err != nil { + t.Errorf("an error '%s' was not expected when opening a stub database connection", err) + } + defer db.Close() + + mock.ExpectPrepare("DELETE FROM users"). + ExpectExec(). + WillDelayFor(time.Second). + WillReturnResult(NewResult(1, 1)) + + ctx, cancel := context.WithCancel(context.Background()) + + go func() { + time.Sleep(time.Millisecond * 10) + cancel() + }() + + stmt, err := db.Prepare("DELETE FROM users") + if err != nil { + t.Errorf("error was not expected, but got: %v", err) + } + + _, err = stmt.ExecContext(ctx) + if err == nil { + t.Error("error was expected, but there was none") + } + + if err != ErrCancelled { + t.Errorf("was expecting cancel error, but got: %v", err) + } + + _, err = stmt.ExecContext(ctx) + if err != context.Canceled { + t.Error("error was expected since context was already done, but there was none") + } + + if err := mock.ExpectationsWereMet(); err != nil { + t.Errorf("there were unfulfilled expectations: %s", err) + } +} + +func TestContextExecWithNamedArg(t *testing.T) { + t.Parallel() + db, mock, err := New() + if err != nil { + t.Errorf("an error '%s' was not expected when opening a stub database connection", err) + } + defer db.Close() + + mock.ExpectExec("DELETE FROM users"). + WithArgs(sql.Named("id", 5)). + WillDelayFor(time.Second). + WillReturnResult(NewResult(1, 1)) + + ctx, cancel := context.WithCancel(context.Background()) + + go func() { + time.Sleep(time.Millisecond * 10) + cancel() + }() + + _, err = db.ExecContext(ctx, "DELETE FROM users WHERE id = :id", sql.Named("id", 5)) + if err == nil { + t.Error("error was expected, but there was none") + } + + if err != ErrCancelled { + t.Errorf("was expecting cancel error, but got: %v", err) + } + + _, err = db.ExecContext(ctx, "DELETE FROM users WHERE id = :id", sql.Named("id", 5)) + if err != context.Canceled { + t.Error("error was expected since context was already done, but there was none") + } + + if err := mock.ExpectationsWereMet(); err != nil { + t.Errorf("there were unfulfilled expectations: %s", err) + } +} + +func TestContextExec(t *testing.T) { + t.Parallel() + db, mock, err := New() + if err != nil { + t.Errorf("an error '%s' was not expected when opening a stub database connection", err) + } + defer db.Close() + + mock.ExpectExec("DELETE FROM users"). + WillReturnResult(NewResult(1, 1)) + + ctx, cancel := context.WithCancel(context.Background()) + + go func() { + time.Sleep(time.Millisecond * 10) + cancel() + }() + + res, err := db.ExecContext(ctx, "DELETE FROM users") + if err != nil { + t.Errorf("error was not expected, but got: %v", err) + } + + affected, err := res.RowsAffected() + if affected != 1 { + t.Errorf("expected affected rows 1, but got %v", affected) + } + + if err != nil { + t.Errorf("error was not expected, but got: %v", err) + } + + if err := mock.ExpectationsWereMet(); err != nil { + t.Errorf("there were unfulfilled expectations: %s", err) + } +} + +func TestContextQueryCancel(t *testing.T) { + t.Parallel() + db, mock, err := New() + if err != nil { + t.Errorf("an error '%s' was not expected when opening a stub database connection", err) + } + defer db.Close() + + rs := NewRows([]string{"id", "title"}).AddRow(5, "hello world") + + mock.ExpectQuery("SELECT (.+) FROM articles WHERE id = ?"). + WithArgs(5). + WillDelayFor(time.Second). + WillReturnRows(rs) + + ctx, cancel := context.WithCancel(context.Background()) + + go func() { + time.Sleep(time.Millisecond * 10) + cancel() + }() + + _, err = db.QueryContext(ctx, "SELECT id, title FROM articles WHERE id = ?", 5) + if err == nil { + t.Error("error was expected, but there was none") + } + + if err != ErrCancelled { + t.Errorf("was expecting cancel error, but got: %v", err) + } + + _, err = db.QueryContext(ctx, "SELECT id, title FROM articles WHERE id = ?", 5) + if err != context.Canceled { + t.Error("error was expected since context was already done, but there was none") + } + + if err := mock.ExpectationsWereMet(); err != nil { + t.Errorf("there were unfulfilled expectations: %s", err) + } +} + +func TestPreparedStatementContextQueryCancel(t *testing.T) { + t.Parallel() + db, mock, err := New() + if err != nil { + t.Errorf("an error '%s' was not expected when opening a stub database connection", err) + } + defer db.Close() + + rs := NewRows([]string{"id", "title"}).AddRow(5, "hello world") + + mock.ExpectPrepare("SELECT (.+) FROM articles WHERE id = ?"). + ExpectQuery(). + WithArgs(5). + WillDelayFor(time.Second). + WillReturnRows(rs) + + ctx, cancel := context.WithCancel(context.Background()) + + go func() { + time.Sleep(time.Millisecond * 10) + cancel() + }() + + stmt, err := db.Prepare("SELECT id, title FROM articles WHERE id = ?") + if err != nil { + t.Errorf("error was not expected, but got: %v", err) + } + + _, err = stmt.QueryContext(ctx, 5) + if err == nil { + t.Error("error was expected, but there was none") + } + + if err != ErrCancelled { + t.Errorf("was expecting cancel error, but got: %v", err) + } + + _, err = stmt.QueryContext(ctx, 5) + if err != context.Canceled { + t.Error("error was expected since context was already done, but there was none") + } + + if err := mock.ExpectationsWereMet(); err != nil { + t.Errorf("there were unfulfilled expectations: %s", err) + } +} + +func TestContextQuery(t *testing.T) { + t.Parallel() + db, mock, err := New() + if err != nil { + t.Errorf("an error '%s' was not expected when opening a stub database connection", err) + } + defer db.Close() + + rs := NewRows([]string{"id", "title"}).AddRow(5, "hello world") + + mock.ExpectQuery("SELECT (.+) FROM articles WHERE id ="). + WithArgs(sql.Named("id", 5)). + WillDelayFor(time.Millisecond * 3). + WillReturnRows(rs) + + ctx, cancel := context.WithCancel(context.Background()) + + go func() { + time.Sleep(time.Millisecond * 10) + cancel() + }() + + rows, err := db.QueryContext(ctx, "SELECT id, title FROM articles WHERE id = :id", sql.Named("id", 5)) + if err != nil { + t.Errorf("error was not expected, but got: %v", err) + } + + if !rows.Next() { + t.Error("expected one row, but there was none") + } + + if err := mock.ExpectationsWereMet(); err != nil { + t.Errorf("there were unfulfilled expectations: %s", err) + } +} + +func TestContextBeginCancel(t *testing.T) { + t.Parallel() + db, mock, err := New() + if err != nil { + t.Errorf("an error '%s' was not expected when opening a stub database connection", err) + } + defer db.Close() + + mock.ExpectBegin().WillDelayFor(time.Second) + + ctx, cancel := context.WithCancel(context.Background()) + + go func() { + time.Sleep(time.Millisecond * 10) + cancel() + }() + + _, err = db.BeginTx(ctx, nil) + if err == nil { + t.Error("error was expected, but there was none") + } + + if err != ErrCancelled { + t.Errorf("was expecting cancel error, but got: %v", err) + } + + _, err = db.BeginTx(ctx, nil) + if err != context.Canceled { + t.Error("error was expected since context was already done, but there was none") + } + + if err := mock.ExpectationsWereMet(); err != nil { + t.Errorf("there were unfulfilled expectations: %s", err) + } +} + +func TestContextBegin(t *testing.T) { + t.Parallel() + db, mock, err := New() + if err != nil { + t.Errorf("an error '%s' was not expected when opening a stub database connection", err) + } + defer db.Close() + + mock.ExpectBegin().WillDelayFor(time.Millisecond * 3) + + ctx, cancel := context.WithCancel(context.Background()) + + go func() { + time.Sleep(time.Millisecond * 10) + cancel() + }() + + tx, err := db.BeginTx(ctx, nil) + if err != nil { + t.Errorf("error was not expected, but got: %v", err) + } + + if tx == nil { + t.Error("expected tx, but there was nil") + } + + if err := mock.ExpectationsWereMet(); err != nil { + t.Errorf("there were unfulfilled expectations: %s", err) + } +} + +func TestContextPrepareCancel(t *testing.T) { + t.Parallel() + db, mock, err := New() + if err != nil { + t.Errorf("an error '%s' was not expected when opening a stub database connection", err) + } + defer db.Close() + + mock.ExpectPrepare("SELECT").WillDelayFor(time.Second) + + ctx, cancel := context.WithCancel(context.Background()) + + go func() { + time.Sleep(time.Millisecond * 10) + cancel() + }() + + _, err = db.PrepareContext(ctx, "SELECT") + if err == nil { + t.Error("error was expected, but there was none") + } + + if err != ErrCancelled { + t.Errorf("was expecting cancel error, but got: %v", err) + } + + _, err = db.PrepareContext(ctx, "SELECT") + if err != context.Canceled { + t.Error("error was expected since context was already done, but there was none") + } + + if err := mock.ExpectationsWereMet(); err != nil { + t.Errorf("there were unfulfilled expectations: %s", err) + } +} + +func TestContextPrepare(t *testing.T) { + t.Parallel() + db, mock, err := New() + if err != nil { + t.Errorf("an error '%s' was not expected when opening a stub database connection", err) + } + defer db.Close() + + mock.ExpectPrepare("SELECT").WillDelayFor(time.Millisecond * 3) + + ctx, cancel := context.WithCancel(context.Background()) + + go func() { + time.Sleep(time.Millisecond * 10) + cancel() + }() + + stmt, err := db.PrepareContext(ctx, "SELECT") + if err != nil { + t.Errorf("error was not expected, but got: %v", err) + } + + if stmt == nil { + t.Error("expected stmt, but there was nil") + } + + if err := mock.ExpectationsWereMet(); err != nil { + t.Errorf("there were unfulfilled expectations: %s", err) + } +} + +func TestContextExecErrorDelay(t *testing.T) { + t.Parallel() + db, mock, err := New() + if err != nil { + t.Errorf("an error '%s' was not expected when opening a stub database connection", err) + } + defer db.Close() + + // test that return of error is delayed + var delay time.Duration + delay = 100 * time.Millisecond + mock.ExpectExec("^INSERT INTO articles"). + WillReturnError(errors.New("slow fail")). + WillDelayFor(delay) + + start := time.Now() + res, err := db.ExecContext(context.Background(), "INSERT INTO articles (title) VALUES (?)", "hello") + stop := time.Now() + + if res != nil { + t.Errorf("result was not expected, was expecting nil") + } + + if err == nil { + t.Errorf("error was expected, was not expecting nil") + } + + if err.Error() != "slow fail" { + t.Errorf("error '%s' was not expected, was expecting '%s'", err.Error(), "slow fail") + } + + elapsed := stop.Sub(start) + if elapsed < delay { + t.Errorf("expecting a delay of %v before error, actual delay was %v", delay, elapsed) + } + + // also test that return of error is not delayed + mock.ExpectExec("^INSERT INTO articles").WillReturnError(errors.New("fast fail")) + + start = time.Now() + db.ExecContext(context.Background(), "INSERT INTO articles (title) VALUES (?)", "hello") + stop = time.Now() + + elapsed = stop.Sub(start) + if elapsed > delay { + t.Errorf("expecting a delay of less than %v before error, actual delay was %v", delay, elapsed) + } +} + +// TestMonitorPingsDisabled verifies backwards-compatibility with behaviour of the library in which +// calls to Ping are not mocked out. It verifies this persists when the user does not enable the new +// behaviour. +func TestMonitorPingsDisabled(t *testing.T) { + t.Parallel() + db, mock, err := New() + if err != nil { + t.Errorf("an error '%s' was not expected when opening a stub database connection", err) + } + defer db.Close() + + // When monitoring of pings is not enabled in the mock, calling Ping should have no effect. + err = db.Ping() + if err != nil { + t.Errorf("monitoring of pings is not enabled so did not expect error from Ping, got '%s'", err) + } + + // Calling ExpectPing should also not register any expectations in the mock. The return from + // ExpectPing should be nil. + expectation := mock.ExpectPing() + if expectation != nil { + t.Errorf("expected ExpectPing to return a nil pointer when monitoring of pings is not enabled") + } + + err = mock.ExpectationsWereMet() + if err != nil { + t.Errorf("monitoring of pings is not enabled so ExpectPing should not register an expectation, got '%s'", err) + } +} + +func TestPingExpectations(t *testing.T) { + t.Parallel() + db, mock, err := New(MonitorPingsOption(true)) + if err != nil { + t.Errorf("an error '%s' was not expected when opening a stub database connection", err) + } + defer db.Close() + + mock.ExpectPing() + if err := db.Ping(); err != nil { + t.Fatal(err) + } + + if err := mock.ExpectationsWereMet(); err != nil { + t.Errorf("there were unfulfilled expectations: %s", err) + } +} + +func TestPingExpectationsErrorDelay(t *testing.T) { + t.Parallel() + db, mock, err := New(MonitorPingsOption(true)) + if err != nil { + t.Errorf("an error '%s' was not expected when opening a stub database connection", err) + } + defer db.Close() + + var delay time.Duration + delay = 100 * time.Millisecond + mock.ExpectPing(). + WillReturnError(errors.New("slow fail")). + WillDelayFor(delay) + + start := time.Now() + err = db.Ping() + stop := time.Now() + + if err == nil { + t.Errorf("result was not expected, was not expecting nil error") + } + + if err.Error() != "slow fail" { + t.Errorf("error '%s' was not expected, was expecting '%s'", err.Error(), "slow fail") + } + + elapsed := stop.Sub(start) + if elapsed < delay { + t.Errorf("expecting a delay of %v before error, actual delay was %v", delay, elapsed) + } + + mock.ExpectPing().WillReturnError(errors.New("fast fail")) + + start = time.Now() + db.Ping() + stop = time.Now() + + elapsed = stop.Sub(start) + if elapsed > delay { + t.Errorf("expecting a delay of less than %v before error, actual delay was %v", delay, elapsed) + } +} + +func TestPingExpectationsMissingPing(t *testing.T) { + t.Parallel() + db, mock, err := New(MonitorPingsOption(true)) + if err != nil { + t.Errorf("an error '%s' was not expected when opening a stub database connection", err) + } + defer db.Close() + + mock.ExpectPing() + + if err = mock.ExpectationsWereMet(); err == nil { + t.Fatalf("was expecting an error, but there wasn't one") + } +} + +func TestPingExpectationsUnexpectedPing(t *testing.T) { + t.Parallel() + db, _, err := New(MonitorPingsOption(true)) + if err != nil { + t.Errorf("an error '%s' was not expected when opening a stub database connection", err) + } + defer db.Close() + + if err = db.Ping(); err == nil { + t.Fatalf("was expecting an error, but there wasn't any") + } +} + +func TestPingOrderedWrongOrder(t *testing.T) { + t.Parallel() + db, mock, err := New(MonitorPingsOption(true)) + if err != nil { + t.Errorf("an error '%s' was not expected when opening a stub database connection", err) + } + defer db.Close() + + mock.ExpectBegin() + mock.ExpectPing() + mock.MatchExpectationsInOrder(true) + + if err = db.Ping(); err == nil { + t.Fatalf("was expecting an error, but there wasn't any") + } +} + +func TestPingExpectationsContextTimeout(t *testing.T) { + t.Parallel() + db, mock, err := New(MonitorPingsOption(true)) + if err != nil { + t.Errorf("an error '%s' was not expected when opening a stub database connection", err) + } + defer db.Close() + + mock.ExpectPing().WillDelayFor(time.Hour) + + ctx, cancel := context.WithTimeout(context.Background(), 50*time.Millisecond) + defer cancel() + + doneCh := make(chan struct{}) + go func() { + err = db.PingContext(ctx) + close(doneCh) + }() + + select { + case <-doneCh: + if err != ErrCancelled { + t.Errorf("expected error '%s' to be returned from Ping, but got '%s'", ErrCancelled, err) + } + case <-time.After(time.Second): + t.Errorf("expected Ping to return after context timeout, but it did not in a timely fashion") + } +} diff --git a/sqlmock_go19.go b/sqlmock_go19.go new file mode 100644 index 0000000..c0f2424 --- /dev/null +++ b/sqlmock_go19.go @@ -0,0 +1,19 @@ +// +build go1.9 + +package sqlmock + +import ( + "database/sql" + "database/sql/driver" +) + +// CheckNamedValue meets https://golang.org/pkg/database/sql/driver/#NamedValueChecker +func (c *sqlmock) CheckNamedValue(nv *driver.NamedValue) (err error) { + switch nv.Value.(type) { + case sql.Out: + return nil + default: + nv.Value, err = c.converter.ConvertValue(nv.Value) + return err + } +} diff --git a/sqlmock_go19_test.go b/sqlmock_go19_test.go new file mode 100644 index 0000000..910d704 --- /dev/null +++ b/sqlmock_go19_test.go @@ -0,0 +1,70 @@ +// +build go1.9 + +package sqlmock + +import ( + "database/sql" + "database/sql/driver" + "errors" + "testing" +) + +func TestStatementTX(t *testing.T) { + t.Parallel() + db, mock, err := New() + if err != nil { + t.Errorf("an error '%s' was not expected when opening a stub database connection", err) + } + defer db.Close() + + prep := mock.ExpectPrepare("SELECT") + mock.ExpectBegin() + + prep.ExpectQuery().WithArgs(1).WillReturnError(errors.New("fast fail")) + + stmt, err := db.Prepare("SELECT title, body FROM articles WHERE id = ?") + if err != nil { + t.Fatalf("unexpected error on prepare: %v", err) + } + + tx, err := db.Begin() + if err != nil { + t.Fatalf("unexpected error on begin: %v", err) + } + + // upgrade connection for statement + txStmt := tx.Stmt(stmt) + _, err = txStmt.Query(1) + if err == nil || err.Error() != "fast fail" { + t.Fatalf("unexpected result: %v", err) + } +} + +func Test_sqlmock_CheckNamedValue(t *testing.T) { + db, mock, err := New() + if err != nil { + t.Errorf("an error '%s' was not expected when opening a stub database connection", err) + } + defer db.Close() + tests := []struct { + name string + arg *driver.NamedValue + wantErr bool + }{ + { + arg: &driver.NamedValue{Name: "test", Value: "test"}, + wantErr: false, + }, + { + arg: &driver.NamedValue{Name: "test", Value: sql.Out{}}, + wantErr: false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if err := mock.(*sqlmock).CheckNamedValue(tt.arg); (err != nil) != tt.wantErr { + t.Errorf("CheckNamedValue() error = %v, wantErr %v", err, tt.wantErr) + } + }) + } +} diff --git a/sqlmock_test.go b/sqlmock_test.go new file mode 100644 index 0000000..ee6b516 --- /dev/null +++ b/sqlmock_test.go @@ -0,0 +1,1342 @@ +package sqlmock + +import ( + "database/sql" + "database/sql/driver" + "errors" + "fmt" + "reflect" + "strconv" + "sync" + "testing" + "time" +) + +func cancelOrder(db *sql.DB, orderID int) error { + tx, _ := db.Begin() + _, _ = tx.Query("SELECT * FROM orders {0} FOR UPDATE", orderID) + err := tx.Rollback() + if err != nil { + return err + } + return nil +} + +func Example() { + // Open new mock database + db, mock, err := New() + if err != nil { + fmt.Println("error creating mock database") + return + } + // columns to be used for result + columns := []string{"id", "status"} + // expect transaction begin + mock.ExpectBegin() + // expect query to fetch order, match it with regexp + mock.ExpectQuery("SELECT (.+) FROM orders (.+) FOR UPDATE"). + WithArgs(1). + WillReturnRows(NewRows(columns).AddRow(1, 1)) + // expect transaction rollback, since order status is "cancelled" + mock.ExpectRollback() + + // run the cancel order function + someOrderID := 1 + // call a function which executes expected database operations + err = cancelOrder(db, someOrderID) + if err != nil { + fmt.Printf("unexpected error: %s", err) + return + } + + // ensure all expectations have been met + if err = mock.ExpectationsWereMet(); err != nil { + fmt.Printf("unmet expectation error: %s", err) + } + // Output: +} + +func TestIssue14EscapeSQL(t *testing.T) { + t.Parallel() + db, mock, err := New() + if err != nil { + t.Errorf("an error '%s' was not expected when opening a stub database connection", err) + } + defer db.Close() + mock.ExpectExec("INSERT INTO mytable\\(a, b\\)"). + WithArgs("A", "B"). + WillReturnResult(NewResult(1, 1)) + + _, err = db.Exec("INSERT INTO mytable(a, b) VALUES (?, ?)", "A", "B") + if err != nil { + t.Errorf("error '%s' was not expected, while inserting a row", err) + } + + if err := mock.ExpectationsWereMet(); err != nil { + t.Errorf("there were unfulfilled expectations: %s", err) + } +} + +// test the case when db is not triggered and expectations +// are not asserted on close +func TestIssue4(t *testing.T) { + t.Parallel() + db, mock, err := New() + if err != nil { + t.Errorf("an error '%s' was not expected when opening a stub database connection", err) + } + defer db.Close() + + mock.ExpectQuery("some sql query which will not be called"). + WillReturnRows(NewRows([]string{"id"})) + + if err := mock.ExpectationsWereMet(); err == nil { + t.Errorf("was expecting an error since query was not triggered") + } +} + +func TestMockQuery(t *testing.T) { + t.Parallel() + db, mock, err := New() + if err != nil { + t.Errorf("an error '%s' was not expected when opening a stub database connection", err) + } + defer db.Close() + + rs := NewRows([]string{"id", "title"}).FromCSVString("5,hello world") + + mock.ExpectQuery("SELECT (.+) FROM articles WHERE id = ?"). + WithArgs(5). + WillReturnRows(rs) + + rows, err := db.Query("SELECT (.+) FROM articles WHERE id = ?", 5) + if err != nil { + t.Errorf("error '%s' was not expected while retrieving mock rows", err) + } + + defer func() { + if er := rows.Close(); er != nil { + t.Error("unexpected error while trying to close rows") + } + }() + + if !rows.Next() { + t.Error("it must have had one row as result, but got empty result set instead") + } + + var id int + var title string + + err = rows.Scan(&id, &title) + if err != nil { + t.Errorf("error '%s' was not expected while trying to scan row", err) + } + + if id != 5 { + t.Errorf("expected mocked id to be 5, but got %d instead", id) + } + + if title != "hello world" { + t.Errorf("expected mocked title to be 'hello world', but got '%s' instead", title) + } + + if err := mock.ExpectationsWereMet(); err != nil { + t.Errorf("there were unfulfilled expectations: %s", err) + } +} + +func TestMockQueryTypes(t *testing.T) { + t.Parallel() + db, mock, err := New() + if err != nil { + t.Errorf("an error '%s' was not expected when opening a stub database connection", err) + } + defer db.Close() + + columns := []string{"id", "timestamp", "sold"} + + timestamp := time.Now() + rs := NewRows(columns) + rs.AddRow(5, timestamp, true) + + mock.ExpectQuery("SELECT (.+) FROM sales WHERE id = ?"). + WithArgs(5). + WillReturnRows(rs) + + rows, err := db.Query("SELECT (.+) FROM sales WHERE id = ?", 5) + if err != nil { + t.Errorf("error '%s' was not expected while retrieving mock rows", err) + } + defer func() { + if er := rows.Close(); er != nil { + t.Error("unexpected error while trying to close rows") + } + }() + if !rows.Next() { + t.Error("it must have had one row as result, but got empty result set instead") + } + + var id int + var time time.Time + var sold bool + + err = rows.Scan(&id, &time, &sold) + if err != nil { + t.Errorf("error '%s' was not expected while trying to scan row", err) + } + + if id != 5 { + t.Errorf("expected mocked id to be 5, but got %d instead", id) + } + + if time != timestamp { + t.Errorf("expected mocked time to be %s, but got '%s' instead", timestamp, time) + } + + if sold != true { + t.Errorf("expected mocked boolean to be true, but got %v instead", sold) + } + + if err := mock.ExpectationsWereMet(); err != nil { + t.Errorf("there were unfulfilled expectations: %s", err) + } +} + +func TestTransactionExpectations(t *testing.T) { + t.Parallel() + db, mock, err := New() + if err != nil { + t.Errorf("an error '%s' was not expected when opening a stub database connection", err) + } + defer db.Close() + + // begin and commit + mock.ExpectBegin() + mock.ExpectCommit() + + tx, err := db.Begin() + if err != nil { + t.Errorf("an error '%s' was not expected when beginning a transaction", err) + } + + err = tx.Commit() + if err != nil { + t.Errorf("an error '%s' was not expected when committing a transaction", err) + } + + // begin and rollback + mock.ExpectBegin() + mock.ExpectRollback() + + tx, err = db.Begin() + if err != nil { + t.Errorf("an error '%s' was not expected when beginning a transaction", err) + } + + err = tx.Rollback() + if err != nil { + t.Errorf("an error '%s' was not expected when rolling back a transaction", err) + } + + // begin with an error + mock.ExpectBegin().WillReturnError(fmt.Errorf("some err")) + + tx, err = db.Begin() + if err == nil { + t.Error("an error was expected when beginning a transaction, but got none") + } + + if err := mock.ExpectationsWereMet(); err != nil { + t.Errorf("there were unfulfilled expectations: %s", err) + } +} + +func TestPrepareExpectations(t *testing.T) { + t.Parallel() + db, mock, err := New() + if err != nil { + t.Errorf("an error '%s' was not expected when opening a stub database connection", err) + } + defer db.Close() + + mock.ExpectPrepare("SELECT (.+) FROM articles WHERE id = ?") + + stmt, err := db.Prepare("SELECT (.+) FROM articles WHERE id = ?") + if err != nil { + t.Errorf("error '%s' was not expected while creating a prepared statement", err) + } + if stmt == nil { + t.Errorf("stmt was expected while creating a prepared statement") + } + + // expect something else, w/o ExpectPrepare() + var id int + var title string + rs := NewRows([]string{"id", "title"}).FromCSVString("5,hello world") + + mock.ExpectQuery("SELECT (.+) FROM articles WHERE id = ?"). + WithArgs(5). + WillReturnRows(rs) + + err = stmt.QueryRow(5).Scan(&id, &title) + if err != nil { + t.Errorf("error '%s' was not expected while retrieving mock rows", err) + } + + mock.ExpectPrepare("SELECT (.+) FROM articles WHERE id = ?"). + WillReturnError(fmt.Errorf("Some DB error occurred")) + + stmt, err = db.Prepare("SELECT id FROM articles WHERE id = ?") + if err == nil { + t.Error("error was expected while creating a prepared statement") + } + if stmt != nil { + t.Errorf("stmt was not expected while creating a prepared statement returning error") + } + + if err := mock.ExpectationsWereMet(); err != nil { + t.Errorf("there were unfulfilled expectations: %s", err) + } +} + +func TestPreparedQueryExecutions(t *testing.T) { + t.Parallel() + db, mock, err := New() + if err != nil { + t.Errorf("an error '%s' was not expected when opening a stub database connection", err) + } + defer db.Close() + + mock.ExpectPrepare("SELECT (.+) FROM articles WHERE id = ?") + + rs1 := NewRows([]string{"id", "title"}).FromCSVString("5,hello world") + mock.ExpectQuery("SELECT (.+) FROM articles WHERE id = ?"). + WithArgs(5). + WillReturnRows(rs1) + + rs2 := NewRows([]string{"id", "title"}).FromCSVString("2,whoop") + mock.ExpectQuery("SELECT (.+) FROM articles WHERE id = ?"). + WithArgs(2). + WillReturnRows(rs2) + + stmt, err := db.Prepare("SELECT id, title FROM articles WHERE id = ?") + if err != nil { + t.Errorf("error '%s' was not expected while creating a prepared statement", err) + } + + var id int + var title string + err = stmt.QueryRow(5).Scan(&id, &title) + if err != nil { + t.Errorf("error '%s' was not expected querying row from statement and scanning", err) + } + + if id != 5 { + t.Errorf("expected mocked id to be 5, but got %d instead", id) + } + + if title != "hello world" { + t.Errorf("expected mocked title to be 'hello world', but got '%s' instead", title) + } + + err = stmt.QueryRow(2).Scan(&id, &title) + if err != nil { + t.Errorf("error '%s' was not expected querying row from statement and scanning", err) + } + + if id != 2 { + t.Errorf("expected mocked id to be 2, but got %d instead", id) + } + + if title != "whoop" { + t.Errorf("expected mocked title to be 'whoop', but got '%s' instead", title) + } + + if err := mock.ExpectationsWereMet(); err != nil { + t.Errorf("there were unfulfilled expectations: %s", err) + } +} + +func TestUnorderedPreparedQueryExecutions(t *testing.T) { + t.Parallel() + db, mock, err := New() + if err != nil { + t.Errorf("an error '%s' was not expected when opening a stub database connection", err) + } + defer db.Close() + + mock.MatchExpectationsInOrder(false) + + mock.ExpectPrepare("SELECT (.+) FROM articles WHERE id = ?"). + ExpectQuery(). + WithArgs(5). + WillReturnRows(NewRows([]string{"id", "title"}).FromCSVString("5,The quick brown fox")) + mock.ExpectPrepare("SELECT (.+) FROM authors WHERE id = ?"). + ExpectQuery(). + WithArgs(1). + WillReturnRows(NewRows([]string{"id", "title"}).FromCSVString("1,Betty B.")) + + var id int + var name string + + stmt, err := db.Prepare("SELECT id, name FROM authors WHERE id = ?") + if err != nil { + t.Errorf("error '%s' was not expected while creating a prepared statement", err) + } + + err = stmt.QueryRow(1).Scan(&id, &name) + if err != nil { + t.Errorf("error '%s' was not expected querying row from statement and scanning", err) + } + + if name != "Betty B." { + t.Errorf("expected mocked name to be 'Betty B.', but got '%s' instead", name) + } +} + +func TestUnexpectedOperations(t *testing.T) { + t.Parallel() + db, mock, err := New() + if err != nil { + t.Errorf("an error '%s' was not expected when opening a stub database connection", err) + } + defer db.Close() + + mock.ExpectPrepare("SELECT (.+) FROM articles WHERE id = ?") + stmt, err := db.Prepare("SELECT id, title FROM articles WHERE id = ?") + if err != nil { + t.Errorf("error '%s' was not expected while creating a prepared statement", err) + } + + var id int + var title string + + err = stmt.QueryRow(5).Scan(&id, &title) + if err == nil { + t.Error("error was expected querying row, since there was no such expectation") + } + + mock.ExpectRollback() + + if err := mock.ExpectationsWereMet(); err == nil { + t.Errorf("was expecting an error since query was not triggered") + } +} + +func TestWrongExpectations(t *testing.T) { + t.Parallel() + db, mock, err := New() + if err != nil { + t.Errorf("an error '%s' was not expected when opening a stub database connection", err) + } + defer db.Close() + + mock.ExpectBegin() + + rs1 := NewRows([]string{"id", "title"}).FromCSVString("5,hello world") + mock.ExpectQuery("SELECT (.+) FROM articles WHERE id = ?"). + WithArgs(5). + WillReturnRows(rs1) + + mock.ExpectCommit().WillReturnError(fmt.Errorf("deadlock occurred")) + mock.ExpectRollback() // won't be triggered + + var id int + var title string + + err = db.QueryRow("SELECT id, title FROM articles WHERE id = ? FOR UPDATE", 5).Scan(&id, &title) + if err == nil { + t.Error("error was expected while querying row, since there begin transaction expectation is not fulfilled") + } + + // lets go around and start transaction + tx, err := db.Begin() + if err != nil { + t.Errorf("an error '%s' was not expected when beginning a transaction", err) + } + + err = db.QueryRow("SELECT id, title FROM articles WHERE id = ? FOR UPDATE", 5).Scan(&id, &title) + if err != nil { + t.Errorf("error '%s' was not expected while querying row, since transaction was started", err) + } + + err = tx.Commit() + if err == nil { + t.Error("a deadlock error was expected when committing a transaction", err) + } + + if err := mock.ExpectationsWereMet(); err == nil { + t.Errorf("was expecting an error since query was not triggered") + } +} + +func TestExecExpectations(t *testing.T) { + t.Parallel() + db, mock, err := New() + if err != nil { + t.Errorf("an error '%s' was not expected when opening a stub database connection", err) + } + defer db.Close() + + result := NewResult(1, 1) + mock.ExpectExec("^INSERT INTO articles"). + WithArgs("hello"). + WillReturnResult(result) + + res, err := db.Exec("INSERT INTO articles (title) VALUES (?)", "hello") + if err != nil { + t.Errorf("error '%s' was not expected, while inserting a row", err) + } + + id, err := res.LastInsertId() + if err != nil { + t.Errorf("error '%s' was not expected, while getting a last insert id", err) + } + + affected, err := res.RowsAffected() + if err != nil { + t.Errorf("error '%s' was not expected, while getting affected rows", err) + } + + if id != 1 { + t.Errorf("expected last insert id to be 1, but got %d instead", id) + } + + if affected != 1 { + t.Errorf("expected affected rows to be 1, but got %d instead", affected) + } + + if err := mock.ExpectationsWereMet(); err != nil { + t.Errorf("there were unfulfilled expectations: %s", err) + } +} + +func TestRowBuilderAndNilTypes(t *testing.T) { + t.Parallel() + db, mock, err := New() + if err != nil { + t.Errorf("an error '%s' was not expected when opening a stub database connection", err) + } + defer db.Close() + + rs := NewRows([]string{"id", "active", "created", "status"}). + AddRow(1, true, time.Now(), 5). + AddRow(2, false, nil, nil) + + mock.ExpectQuery("SELECT (.+) FROM sales").WillReturnRows(rs) + + rows, err := db.Query("SELECT * FROM sales") + if err != nil { + t.Errorf("error '%s' was not expected while retrieving mock rows", err) + } + defer func() { + if er := rows.Close(); er != nil { + t.Error("Unexpected error while trying to close rows") + } + }() + + // NullTime and NullInt are used from stubs_test.go + var ( + id int + active bool + created NullTime + status NullInt + ) + + if !rows.Next() { + t.Error("it must have had row in rows, but got empty result set instead") + } + + err = rows.Scan(&id, &active, &created, &status) + if err != nil { + t.Errorf("error '%s' was not expected while trying to scan row", err) + } + + if id != 1 { + t.Errorf("expected mocked id to be 1, but got %d instead", id) + } + + if !active { + t.Errorf("expected 'active' to be 'true', but got '%v' instead", active) + } + + if !created.Valid { + t.Errorf("expected 'created' to be valid, but it %+v is not", created) + } + + if !status.Valid { + t.Errorf("expected 'status' to be valid, but it %+v is not", status) + } + + if status.Integer != 5 { + t.Errorf("expected 'status' to be '5', but got '%d'", status.Integer) + } + + // test second row + if !rows.Next() { + t.Error("it must have had row in rows, but got empty result set instead") + } + + err = rows.Scan(&id, &active, &created, &status) + if err != nil { + t.Errorf("error '%s' was not expected while trying to scan row", err) + } + + if id != 2 { + t.Errorf("expected mocked id to be 2, but got %d instead", id) + } + + if active { + t.Errorf("expected 'active' to be 'false', but got '%v' instead", active) + } + + if created.Valid { + t.Errorf("expected 'created' to be invalid, but it %+v is not", created) + } + + if status.Valid { + t.Errorf("expected 'status' to be invalid, but it %+v is not", status) + } + + if err := mock.ExpectationsWereMet(); err != nil { + t.Errorf("there were unfulfilled expectations: %s", err) + } +} + +func TestArgumentReflectValueTypeError(t *testing.T) { + t.Parallel() + db, mock, err := New() + if err != nil { + t.Errorf("an error '%s' was not expected when opening a stub database connection", err) + } + defer db.Close() + + rs := NewRows([]string{"id"}).AddRow(1) + + mock.ExpectQuery("SELECT (.+) FROM sales").WithArgs(5.5).WillReturnRows(rs) + + _, err = db.Query("SELECT * FROM sales WHERE x = ?", 5) + if err == nil { + t.Error("expected error, but got none") + } +} + +func TestGoroutineExecutionWithUnorderedExpectationMatching(t *testing.T) { + t.Parallel() + db, mock, err := New() + if err != nil { + t.Errorf("an error '%s' was not expected when opening a stub database connection", err) + } + defer db.Close() + + // note this line is important for unordered expectation matching + mock.MatchExpectationsInOrder(false) + + result := NewResult(1, 1) + + mock.ExpectExec("^UPDATE one").WithArgs("one").WillReturnResult(result) + mock.ExpectExec("^UPDATE two").WithArgs("one", "two").WillReturnResult(result) + mock.ExpectExec("^UPDATE three").WithArgs("one", "two", "three").WillReturnResult(result) + + var wg sync.WaitGroup + queries := map[string][]interface{}{ + "one": {"one"}, + "two": {"one", "two"}, + "three": {"one", "two", "three"}, + } + + wg.Add(len(queries)) + for table, args := range queries { + go func(tbl string, a []interface{}) { + if _, err := db.Exec("UPDATE "+tbl, a...); err != nil { + t.Errorf("error was not expected: %s", err) + } + wg.Done() + }(table, args) + } + + wg.Wait() + + if err := mock.ExpectationsWereMet(); err != nil { + t.Errorf("there were unfulfilled expectations: %s", err) + } +} + +func ExampleSqlmock_goroutines() { + db, mock, err := New() + if err != nil { + fmt.Println("failed to open sqlmock database:", err) + } + defer db.Close() + + // note this line is important for unordered expectation matching + mock.MatchExpectationsInOrder(false) + + result := NewResult(1, 1) + + mock.ExpectExec("^UPDATE one").WithArgs("one").WillReturnResult(result) + mock.ExpectExec("^UPDATE two").WithArgs("one", "two").WillReturnResult(result) + mock.ExpectExec("^UPDATE three").WithArgs("one", "two", "three").WillReturnResult(result) + + var wg sync.WaitGroup + queries := map[string][]interface{}{ + "one": {"one"}, + "two": {"one", "two"}, + "three": {"one", "two", "three"}, + } + + wg.Add(len(queries)) + for table, args := range queries { + go func(tbl string, a []interface{}) { + if _, err := db.Exec("UPDATE "+tbl, a...); err != nil { + fmt.Println("error was not expected:", err) + } + wg.Done() + }(table, args) + } + + wg.Wait() + + if err := mock.ExpectationsWereMet(); err != nil { + fmt.Println("there were unfulfilled expectations:", err) + } + // Output: +} + +// False Positive - passes despite mismatched Exec +// see #37 issue +func TestRunExecsWithOrderedShouldNotMeetAllExpectations(t *testing.T) { + db, dbmock, _ := New() + dbmock.ExpectExec("THE FIRST EXEC") + dbmock.ExpectExec("THE SECOND EXEC") + + _, _ = db.Exec("THE FIRST EXEC") + _, _ = db.Exec("THE WRONG EXEC") + + err := dbmock.ExpectationsWereMet() + if err == nil { + t.Fatal("was expecting an error, but there wasn't any") + } +} + +// False Positive - passes despite mismatched Exec +// see #37 issue +func TestRunQueriesWithOrderedShouldNotMeetAllExpectations(t *testing.T) { + db, dbmock, _ := New() + dbmock.ExpectQuery("THE FIRST QUERY") + dbmock.ExpectQuery("THE SECOND QUERY") + + _, _ = db.Query("THE FIRST QUERY") + _, _ = db.Query("THE WRONG QUERY") + + err := dbmock.ExpectationsWereMet() + if err == nil { + t.Fatal("was expecting an error, but there wasn't any") + } +} + +func TestRunExecsWithExpectedErrorMeetsExpectations(t *testing.T) { + db, dbmock, _ := New() + dbmock.ExpectExec("THE FIRST EXEC").WillReturnError(fmt.Errorf("big bad bug")) + dbmock.ExpectExec("THE SECOND EXEC").WillReturnResult(NewResult(0, 0)) + + _, _ = db.Exec("THE FIRST EXEC") + _, _ = db.Exec("THE SECOND EXEC") + + err := dbmock.ExpectationsWereMet() + if err != nil { + t.Fatalf("all expectations should be met: %s", err) + } +} + +func TestRunQueryWithExpectedErrorMeetsExpectations(t *testing.T) { + db, dbmock, _ := New() + dbmock.ExpectQuery("THE FIRST QUERY").WillReturnError(fmt.Errorf("big bad bug")) + dbmock.ExpectQuery("THE SECOND QUERY").WillReturnRows(NewRows([]string{"col"}).AddRow(1)) + + _, _ = db.Query("THE FIRST QUERY") + _, _ = db.Query("THE SECOND QUERY") + + err := dbmock.ExpectationsWereMet() + if err != nil { + t.Fatalf("all expectations should be met: %s", err) + } +} + +func TestEmptyRowSet(t *testing.T) { + t.Parallel() + db, mock, err := New() + if err != nil { + t.Errorf("an error '%s' was not expected when opening a stub database connection", err) + } + defer db.Close() + + rs := NewRows([]string{"id", "title"}) + + mock.ExpectQuery("SELECT (.+) FROM articles WHERE id = ?"). + WithArgs(5). + WillReturnRows(rs) + + rows, err := db.Query("SELECT (.+) FROM articles WHERE id = ?", 5) + if err != nil { + t.Errorf("error '%s' was not expected while retrieving mock rows", err) + } + + defer func() { + if er := rows.Close(); er != nil { + t.Error("unexpected error while trying to close rows") + } + }() + + if rows.Next() { + t.Error("expected no rows but got one") + } + + err = mock.ExpectationsWereMet() + if err != nil { + t.Fatalf("all expectations should be met: %s", err) + } +} + +// Based on issue #50 +func TestPrepareExpectationNotFulfilled(t *testing.T) { + t.Parallel() + db, mock, err := New() + if err != nil { + t.Errorf("an error '%s' was not expected when opening a stub database connection", err) + } + defer db.Close() + + mock.ExpectPrepare("^BADSELECT$") + + if _, err := db.Prepare("SELECT"); err == nil { + t.Fatal("prepare should not match expected query string") + } + + if err := mock.ExpectationsWereMet(); err == nil { + t.Errorf("was expecting an error, since prepared statement query does not match, but there was none") + } +} + +func TestRollbackThrow(t *testing.T) { + // Open new mock database + db, mock, err := New() + if err != nil { + fmt.Println("error creating mock database") + return + } + // columns to be used for result + columns := []string{"id", "status"} + // expect transaction begin + mock.ExpectBegin() + // expect query to fetch order, match it with regexp + mock.ExpectQuery("SELECT (.+) FROM orders (.+) FOR UPDATE"). + WithArgs(1). + WillReturnRows(NewRows(columns).AddRow(1, 1)) + // expect transaction rollback, since order status is "cancelled" + mock.ExpectRollback().WillReturnError(fmt.Errorf("rollback failed")) + + // run the cancel order function + someOrderID := 1 + // call a function which executes expected database operations + err = cancelOrder(db, someOrderID) + if err == nil { + t.Error("an error was expected when rolling back transaction, but got none") + } + + // ensure all expectations have been met + if err = mock.ExpectationsWereMet(); err != nil { + t.Errorf("unmet expectation error: %s", err) + } + // Output: +} + +func TestUnexpectedBegin(t *testing.T) { + // Open new mock database + db, _, err := New() + if err != nil { + fmt.Println("error creating mock database") + return + } + if _, err := db.Begin(); err == nil { + t.Error("an error was expected when calling begin, but got none") + } +} + +func TestUnexpectedExec(t *testing.T) { + // Open new mock database + db, mock, err := New() + if err != nil { + fmt.Println("error creating mock database") + return + } + mock.ExpectBegin() + db.Begin() + if _, err := db.Exec("SELECT 1"); err == nil { + t.Error("an error was expected when calling exec, but got none") + } +} + +func TestUnexpectedCommit(t *testing.T) { + // Open new mock database + db, mock, err := New() + if err != nil { + fmt.Println("error creating mock database") + return + } + mock.ExpectBegin() + tx, _ := db.Begin() + if err := tx.Commit(); err == nil { + t.Error("an error was expected when calling commit, but got none") + } +} + +func TestUnexpectedCommitOrder(t *testing.T) { + // Open new mock database + db, mock, err := New() + if err != nil { + fmt.Println("error creating mock database") + return + } + mock.ExpectBegin() + mock.ExpectRollback().WillReturnError(fmt.Errorf("Rollback failed")) + tx, _ := db.Begin() + if err := tx.Commit(); err == nil { + t.Error("an error was expected when calling commit, but got none") + } +} + +func TestExpectedCommitOrder(t *testing.T) { + // Open new mock database + db, mock, err := New() + if err != nil { + fmt.Println("error creating mock database") + return + } + mock.ExpectCommit().WillReturnError(fmt.Errorf("Commit failed")) + if _, err := db.Begin(); err == nil { + t.Error("an error was expected when calling begin, but got none") + } +} + +func TestUnexpectedRollback(t *testing.T) { + // Open new mock database + db, mock, err := New() + if err != nil { + fmt.Println("error creating mock database") + return + } + mock.ExpectBegin() + tx, _ := db.Begin() + if err := tx.Rollback(); err == nil { + t.Error("an error was expected when calling rollback, but got none") + } +} + +func TestUnexpectedRollbackOrder(t *testing.T) { + // Open new mock database + db, mock, err := New() + if err != nil { + fmt.Println("error creating mock database") + return + } + mock.ExpectBegin() + + tx, _ := db.Begin() + if err := tx.Rollback(); err == nil { + t.Error("an error was expected when calling rollback, but got none") + } +} + +func TestPrepareExec(t *testing.T) { + // Open new mock database + db, mock, err := New() + if err != nil { + fmt.Println("error creating mock database") + return + } + defer db.Close() + mock.ExpectBegin() + ep := mock.ExpectPrepare("INSERT INTO ORDERS\\(ID, STATUS\\) VALUES \\(\\?, \\?\\)") + for i := 0; i < 3; i++ { + ep.ExpectExec().WillReturnResult(NewResult(1, 1)) + } + mock.ExpectCommit() + tx, _ := db.Begin() + stmt, err := tx.Prepare("INSERT INTO ORDERS(ID, STATUS) VALUES (?, ?)") + if err != nil { + t.Fatal(err) + } + defer stmt.Close() + for i := 0; i < 3; i++ { + _, err := stmt.Exec(i, "Hello"+strconv.Itoa(i)) + if err != nil { + t.Fatal(err) + } + } + tx.Commit() + if err := mock.ExpectationsWereMet(); err != nil { + t.Errorf("there were unfulfilled expectations: %s", err) + } +} + +func TestPrepareQuery(t *testing.T) { + // Open new mock database + db, mock, err := New() + if err != nil { + fmt.Println("error creating mock database") + return + } + defer db.Close() + mock.ExpectBegin() + ep := mock.ExpectPrepare("SELECT ID, STATUS FROM ORDERS WHERE ID = \\?") + ep.ExpectQuery().WithArgs(101).WillReturnRows(NewRows([]string{"ID", "STATUS"}).AddRow(101, "Hello")) + mock.ExpectCommit() + tx, _ := db.Begin() + stmt, err := tx.Prepare("SELECT ID, STATUS FROM ORDERS WHERE ID = ?") + if err != nil { + t.Fatal(err) + } + defer stmt.Close() + rows, err := stmt.Query(101) + if err != nil { + t.Fatal(err) + } + defer rows.Close() + for rows.Next() { + var ( + id int + status string + ) + if rows.Scan(&id, &status); id != 101 || status != "Hello" { + t.Fatal("wrong query results") + } + + } + tx.Commit() + if err := mock.ExpectationsWereMet(); err != nil { + t.Errorf("there were unfulfilled expectations: %s", err) + } +} + +func TestExpectedCloseError(t *testing.T) { + // Open new mock database + db, mock, err := New() + if err != nil { + fmt.Println("error creating mock database") + return + } + mock.ExpectClose().WillReturnError(fmt.Errorf("Close failed")) + if err := db.Close(); err == nil { + t.Error("an error was expected when calling close, but got none") + } + if err := mock.ExpectationsWereMet(); err != nil { + t.Errorf("there were unfulfilled expectations: %s", err) + } +} + +func TestExpectedCloseOrder(t *testing.T) { + // Open new mock database + db, mock, err := New() + if err != nil { + fmt.Println("error creating mock database") + return + } + defer db.Close() + mock.ExpectClose().WillReturnError(fmt.Errorf("Close failed")) + db.Begin() + if err := mock.ExpectationsWereMet(); err == nil { + t.Error("expected error on ExpectationsWereMet") + } +} + +func TestExpectedBeginOrder(t *testing.T) { + // Open new mock database + db, mock, err := New() + if err != nil { + fmt.Println("error creating mock database") + return + } + mock.ExpectBegin().WillReturnError(fmt.Errorf("Begin failed")) + if err := db.Close(); err == nil { + t.Error("an error was expected when calling close, but got none") + } +} + +func TestPreparedStatementCloseExpectation(t *testing.T) { + // Open new mock database + db, mock, err := New() + if err != nil { + fmt.Println("error creating mock database") + return + } + defer db.Close() + + ep := mock.ExpectPrepare("INSERT INTO ORDERS").WillBeClosed() + ep.ExpectExec().WillReturnResult(NewResult(1, 1)) + + stmt, err := db.Prepare("INSERT INTO ORDERS(ID, STATUS) VALUES (?, ?)") + if err != nil { + t.Fatal(err) + } + + if _, err := stmt.Exec(1, "Hello"); err != nil { + t.Fatal(err) + } + + if err := stmt.Close(); err != nil { + t.Fatal(err) + } + + if err := mock.ExpectationsWereMet(); err != nil { + t.Errorf("there were unfulfilled expectations: %s", err) + } +} + +func TestExecExpectationErrorDelay(t *testing.T) { + t.Parallel() + db, mock, err := New() + if err != nil { + t.Errorf("an error '%s' was not expected when opening a stub database connection", err) + } + defer db.Close() + + // test that return of error is delayed + var delay time.Duration + delay = 100 * time.Millisecond + mock.ExpectExec("^INSERT INTO articles"). + WillReturnError(errors.New("slow fail")). + WillDelayFor(delay) + + start := time.Now() + res, err := db.Exec("INSERT INTO articles (title) VALUES (?)", "hello") + stop := time.Now() + + if res != nil { + t.Errorf("result was not expected, was expecting nil") + } + + if err == nil { + t.Errorf("error was expected, was not expecting nil") + } + + if err.Error() != "slow fail" { + t.Errorf("error '%s' was not expected, was expecting '%s'", err.Error(), "slow fail") + } + + elapsed := stop.Sub(start) + if elapsed < delay { + t.Errorf("expecting a delay of %v before error, actual delay was %v", delay, elapsed) + } + + // also test that return of error is not delayed + mock.ExpectExec("^INSERT INTO articles").WillReturnError(errors.New("fast fail")) + + start = time.Now() + db.Exec("INSERT INTO articles (title) VALUES (?)", "hello") + stop = time.Now() + + elapsed = stop.Sub(start) + if elapsed > delay { + t.Errorf("expecting a delay of less than %v before error, actual delay was %v", delay, elapsed) + } +} + +func TestOptionsFail(t *testing.T) { + t.Parallel() + expected := errors.New("failing option") + option := func(*sqlmock) error { + return expected + } + db, _, err := New(option) + defer db.Close() + if err == nil { + t.Errorf("missing expecting error '%s' when opening a stub database connection", expected) + } +} + +func TestNewRows(t *testing.T) { + t.Parallel() + db, mock, err := New() + if err != nil { + t.Errorf("an error '%s' was not expected when opening a stub database connection", err) + } + defer db.Close() + columns := []string{"col1", "col2"} + + r := mock.NewRows(columns) + if len(r.cols) != len(columns) || r.cols[0] != columns[0] || r.cols[1] != columns[1] { + t.Errorf("expecting to create a row with columns %v, actual colmns are %v", r.cols, columns) + } +} + +// This is actually a test of ExpectationsWereMet. Without a lock around e.fulfilled() inside +// ExpectationWereMet, the race detector complains if e.triggered is being read while it is also +// being written by the query running in another goroutine. +func TestQueryWithTimeout(t *testing.T) { + db, mock, err := New() + if err != nil { + t.Errorf("an error '%s' was not expected when opening a stub database connection", err) + } + defer db.Close() + + rs := NewRows([]string{"id", "title"}).FromCSVString("5,hello world") + + mock.ExpectQuery("SELECT (.+) FROM articles WHERE id = ?"). + WillDelayFor(15 * time.Millisecond). // Query will take longer than timeout + WithArgs(5). + WillReturnRows(rs) + + _, err = queryWithTimeout(10*time.Millisecond, db, "SELECT (.+) FROM articles WHERE id = ?", 5) + if err == nil { + t.Errorf("expecting query to time out") + } + + if err := mock.ExpectationsWereMet(); err != nil { + t.Errorf("there were unfulfilled expectations: %s", err) + } +} + +func queryWithTimeout(t time.Duration, db *sql.DB, query string, args ...interface{}) (*sql.Rows, error) { + rowsChan := make(chan *sql.Rows, 1) + errChan := make(chan error, 1) + + go func() { + rows, err := db.Query(query, args...) + if err != nil { + errChan <- err + return + } + rowsChan <- rows + }() + + select { + case rows := <-rowsChan: + return rows, nil + case err := <-errChan: + return nil, err + case <-time.After(t): + return nil, fmt.Errorf("query timed out after %v", t) + } +} + +func Test_sqlmock_Prepare_and_Exec(t *testing.T) { + db, mock, err := New() + if err != nil { + t.Errorf("an error '%s' was not expected when opening a stub database connection", err) + } + defer db.Close() + query := "SELECT name, email FROM users WHERE name = ?" + + mock.ExpectPrepare("SELECT (.+) FROM users WHERE (.+)") + expected := NewResult(1, 1) + mock.ExpectExec("SELECT (.+) FROM users WHERE (.+)"). + WillReturnResult(expected) + expectedRows := mock.NewRows([]string{"id", "name", "email"}).AddRow(1, "test", "test@example.com") + mock.ExpectQuery("SELECT (.+) FROM users WHERE (.+)").WillReturnRows(expectedRows) + + got, err := mock.(*sqlmock).Prepare(query) + if err != nil { + t.Error(err) + return + } + if got == nil { + t.Error("Prepare () stmt must not be nil") + return + } + result, err := got.Exec([]driver.Value{"test"}) + if err != nil { + t.Error(err) + return + } + if !reflect.DeepEqual(result, expected) { + t.Errorf("Results are not equal. Expected: %v, Actual: %v", expected, result) + return + } + rows, err := got.Query([]driver.Value{"test"}) + if err != nil { + t.Error(err) + return + } + defer rows.Close() +} + +type failArgument struct{} + +func (f failArgument) Match(_ driver.Value) bool { + return false +} + +func Test_sqlmock_Exec(t *testing.T) { + db, mock, err := New() + if err != nil { + t.Errorf("an error '%s' was not expected when opening a stub database connection", err) + } + defer db.Close() + + mock.ExpectBegin() + _, err = mock.(*sqlmock).Exec("", []driver.Value{}) + if err == nil { + t.Errorf("error expected") + return + } + + expected := NewResult(1, 1) + mock.ExpectExec("SELECT (.+) FROM users WHERE (.+)"). + WillReturnResult(expected). + WithArgs("test") + + matchErr := errors.New("matcher sqlmock.failArgument could not match 0 argument driver.NamedValue - {Name: Ordinal:1 Value:{}}") + mock.ExpectExec("SELECT (.+) FROM animals WHERE (.+)"). + WillReturnError(matchErr). + WithArgs(failArgument{}) + + mock.ExpectExec("").WithArgs(failArgument{}) + + mock.(*sqlmock).expected = mock.(*sqlmock).expected[1:] + query := "SELECT name, email FROM users WHERE name = ?" + result, err := mock.(*sqlmock).Exec(query, []driver.Value{"test"}) + if err != nil { + t.Error(err) + return + } + if !reflect.DeepEqual(result, expected) { + t.Errorf("Results are not equal. Expected: %v, Actual: %v", expected, result) + return + } + + failQuery := "SELECT name, sex FROM animals WHERE sex = ?" + _, err = mock.(*sqlmock).Exec(failQuery, []driver.Value{failArgument{}}) + if err == nil { + t.Errorf("error expected") + return + } + mock.(*sqlmock).ordered = false + _, err = mock.(*sqlmock).Exec("", []driver.Value{failArgument{}}) + if err == nil { + t.Errorf("error expected") + return + } +} + +func Test_sqlmock_Query(t *testing.T) { + db, mock, err := New() + if err != nil { + t.Errorf("an error '%s' was not expected when opening a stub database connection", err) + } + defer db.Close() + expectedRows := mock.NewRows([]string{"id", "name", "email"}).AddRow(1, "test", "test@example.com") + mock.ExpectQuery("SELECT (.+) FROM users WHERE (.+)").WillReturnRows(expectedRows) + query := "SELECT name, email FROM users WHERE name = ?" + rows, err := mock.(*sqlmock).Query(query, []driver.Value{"test"}) + if err != nil { + t.Error(err) + return + } + defer rows.Close() + _, err = mock.(*sqlmock).Query(query, []driver.Value{failArgument{}}) + if err == nil { + t.Errorf("error expected") + return + } +} diff --git a/statement.go b/statement.go new file mode 100644 index 0000000..852b8f3 --- /dev/null +++ b/statement.go @@ -0,0 +1,16 @@ +package sqlmock + +type statement struct { + conn *sqlmock + ex *ExpectedPrepare + query string +} + +func (stmt *statement) Close() error { + stmt.ex.wasClosed = true + return stmt.ex.closeErr +} + +func (stmt *statement) NumInput() int { + return -1 +} diff --git a/statement_before_go18.go b/statement_before_go18.go new file mode 100644 index 0000000..e2cac2b --- /dev/null +++ b/statement_before_go18.go @@ -0,0 +1,17 @@ +// +build !go1.8 + +package sqlmock + +import ( + "database/sql/driver" +) + +// Deprecated: Drivers should implement ExecerContext instead. +func (stmt *statement) Exec(args []driver.Value) (driver.Result, error) { + return stmt.conn.Exec(stmt.query, args) +} + +// Deprecated: Drivers should implement StmtQueryContext instead (or additionally). +func (stmt *statement) Query(args []driver.Value) (driver.Rows, error) { + return stmt.conn.Query(stmt.query, args) +} diff --git a/statement_go18.go b/statement_go18.go new file mode 100644 index 0000000..e083051 --- /dev/null +++ b/statement_go18.go @@ -0,0 +1,26 @@ +// +build go1.8 + +package sqlmock + +import ( + "context" + "database/sql/driver" +) + +// Deprecated: Drivers should implement ExecerContext instead. +func (stmt *statement) Exec(args []driver.Value) (driver.Result, error) { + return stmt.conn.ExecContext(context.Background(), stmt.query, convertValueToNamedValue(args)) +} + +// Deprecated: Drivers should implement StmtQueryContext instead (or additionally). +func (stmt *statement) Query(args []driver.Value) (driver.Rows, error) { + return stmt.conn.QueryContext(context.Background(), stmt.query, convertValueToNamedValue(args)) +} + +func convertValueToNamedValue(args []driver.Value) []driver.NamedValue { + namedArgs := make([]driver.NamedValue, len(args)) + for i, v := range args { + namedArgs[i] = driver.NamedValue{Ordinal: i + 1, Value: v} + } + return namedArgs +} diff --git a/statement_test.go b/statement_test.go new file mode 100644 index 0000000..fc4c541 --- /dev/null +++ b/statement_test.go @@ -0,0 +1,33 @@ +// +build go1.6 + +package sqlmock + +import ( + "errors" + "testing" +) + +func TestExpectedPreparedStatementCloseError(t *testing.T) { + conn, mock, err := New() + if err != nil { + t.Fatal("failed to open sqlmock database:", err) + } + + mock.ExpectBegin() + want := errors.New("STMT ERROR") + mock.ExpectPrepare("SELECT").WillReturnCloseError(want) + + txn, err := conn.Begin() + if err != nil { + t.Fatal("unexpected error while opening transaction:", err) + } + + stmt, err := txn.Prepare("SELECT") + if err != nil { + t.Fatal("unexpected error while preparing a statement:", err) + } + + if err := stmt.Close(); err != want { + t.Fatalf("got = %v, want = %v", err, want) + } +} diff --git a/stubs_test.go b/stubs_test.go new file mode 100644 index 0000000..b35be2e --- /dev/null +++ b/stubs_test.go @@ -0,0 +1,80 @@ +package sqlmock + +import ( + "database/sql/driver" + "errors" + "fmt" + "strconv" + "time" +) + +type NullTime struct { + Time time.Time + Valid bool // Valid is true if Time is not NULL +} + +type NullInt struct { + Integer int + Valid bool +} + +// Satisfy sql.Scanner interface +func (ni *NullInt) Scan(value interface{}) error { + switch v := value.(type) { + case nil: + ni.Integer, ni.Valid = 0, false + case int64: + const maxUint = ^uint(0) + const maxInt = int(maxUint >> 1) + const minInt = -maxInt - 1 + + if v > int64(maxInt) || v < int64(minInt) { + return errors.New("value out of int range") + } + ni.Integer, ni.Valid = int(v), true + case []byte: + n, err := strconv.Atoi(string(v)) + if err != nil { + return err + } + ni.Integer, ni.Valid = n, true + case string: + n, err := strconv.Atoi(v) + if err != nil { + return err + } + ni.Integer, ni.Valid = n, true + default: + return fmt.Errorf("can't convert %T to integer", value) + } + return nil +} + +// Satisfy sql.Valuer interface. +func (ni NullInt) Value() (driver.Value, error) { + if !ni.Valid { + return nil, nil + } + return int64(ni.Integer), nil +} + +// Satisfy sql.Scanner interface +func (nt *NullTime) Scan(value interface{}) error { + switch v := value.(type) { + case nil: + nt.Time, nt.Valid = time.Time{}, false + case time.Time: + nt.Time, nt.Valid = v, true + default: + return fmt.Errorf("can't convert %T to time.Time", value) + } + return nil +} + +// Satisfy sql.Valuer interface. +func (nt NullTime) Value() (driver.Value, error) { + if !nt.Valid { + return nil, nil + } + return nt.Time, nil +}