diff --git a/connection.go b/connection.go index cb58a839..c71b2e53 100644 --- a/connection.go +++ b/connection.go @@ -90,14 +90,23 @@ func (c *Connection) Open() error { if c.Store != nil { return nil } + if c.Dialect == nil { + return errors.New("invalid connection instance") + } details := c.Dialect.Details() db, err := sqlx.Open(details.Dialect, c.Dialect.URL()) + if err != nil { + return errors.Wrap(err, "could not open database connection") + } db.SetMaxOpenConns(details.Pool) db.SetMaxIdleConns(details.IdlePool) - if err == nil { - c.Store = &dB{db} + c.Store = &dB{db} + + err = c.Dialect.afterOpen(c) + if err != nil { + c.Store = nil } - return errors.Wrap(err, "couldn't connect to database") + return errors.Wrap(err, "could not open database connection") } // Close destroys an active datasource connection diff --git a/connection_test.go b/connection_test.go index d1d86cf5..5fc8b284 100644 --- a/connection_test.go +++ b/connection_test.go @@ -1 +1,52 @@ package pop + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +func Test_Connection_SimpleFlow(t *testing.T) { + r := require.New(t) + + cd := &ConnectionDetails{ + URL: "sqlite:///foo.db", + } + c, err := NewConnection(cd) + r.NoError(err) + + err = c.Open() + r.NoError(err) + err = c.Open() // open again + r.NoError(err) + err = c.Close() + r.NoError(err) +} + +func Test_Connection_Open_NoDialect(t *testing.T) { + r := require.New(t) + + cd := &ConnectionDetails{ + URL: "sqlite:///foo.db", + } + c, err := NewConnection(cd) + r.NoError(err) + + c.Dialect = nil + err = c.Open() + r.Error(err) +} + +func Test_Connection_Open_BadDialect(t *testing.T) { + r := require.New(t) + + cd := &ConnectionDetails{ + URL: "sqlite:///foo.db", + } + c, err := NewConnection(cd) + r.NoError(err) + + cd.Dialect = "unknown" + err = c.Open() + r.Error(err) +} diff --git a/dialect.go b/dialect.go index e28e8821..11b961d1 100644 --- a/dialect.go +++ b/dialect.go @@ -40,6 +40,7 @@ type dialect interface { FizzTranslator() fizz.Translator Lock(func() error) error TruncateAll(*Connection) error + afterOpen(*Connection) error } func genericCreate(s store, model *Model, cols columns.Columns) error { diff --git a/dialect_cockroach.go b/dialect_cockroach.go index cc8ab3eb..87759657 100644 --- a/dialect_cockroach.go +++ b/dialect_cockroach.go @@ -32,19 +32,19 @@ func init() { var _ dialect = &cockroach{} // ServerInfo holds informational data about connected database server. -type ServerInfo struct { +type cockroachInfo struct { VersionString string `db:"version"` - Product string `db:"-"` - License string `db:"-"` - Version string `db:"-"` - BuildInfo string `db:"-"` + product string `db:"-"` + license string `db:"-"` + version string `db:"-"` + buildInfo string `db:"-"` } type cockroach struct { translateCache map[string]string mu sync.Mutex ConnectionDetails *ConnectionDetails - Server ServerInfo + info cockroachInfo } func (p *cockroach) Name() string { @@ -218,33 +218,13 @@ func (p *cockroach) LoadSchema(r io.Reader) error { return genericLoadSchema(p.ConnectionDetails, p.MigrationURL(), r) } -func (p *cockroach) FillServerInfo(tx *Connection) error { - if err := tx.RawQuery(`select version() AS "version"`).First(&p.Server); err != nil { - return err - } - if s := strings.Split(p.Server.VersionString, " "); len(s) > 3 { - p.Server.Product = s[0] - p.Server.License = s[1] - p.Server.Version = s[2] - p.Server.BuildInfo = s[3] - } - log(logging.Debug, "server: %v %v %v", p.Server.Product, p.Server.License, p.Server.Version) - - return nil -} - func (p *cockroach) TruncateAll(tx *Connection) error { type table struct { TableName string `db:"table_name"` } - // move it to `newCockroach()` if it need more - if err := p.FillServerInfo(tx); err != nil { - return err - } - tableQuery := "select table_name from information_schema.tables where table_schema = 'public' and table_type = 'BASE TABLE' and table_catalog = ?" - if strings.HasPrefix(p.Server.Version, "v1") { + if strings.HasPrefix(p.info.version, "v1") { tableQuery = "select table_name from information_schema.tables where table_schema = ?" } @@ -272,6 +252,21 @@ func (p *cockroach) TruncateAll(tx *Connection) error { // return tx3.RawQuery(fmt.Sprintf("truncate %s cascade;", strings.Join(tableNames, ", "))).Exec() } +func (p *cockroach) afterOpen(c *Connection) error { + if err := c.RawQuery(`select version() AS "version"`).First(&p.info); err != nil { + return err + } + if s := strings.Split(p.info.VersionString, " "); len(s) > 3 { + p.info.product = s[0] + p.info.license = s[1] + p.info.version = s[2] + p.info.buildInfo = s[3] + } + log(logging.Debug, "server: %v %v %v", p.info.product, p.info.license, p.info.version) + + return nil +} + func newCockroach(deets *ConnectionDetails) (dialect, error) { deets.Dialect = "postgres" cd := &cockroach{ diff --git a/dialect_mysql.go b/dialect_mysql.go index 833b5c00..a4013663 100644 --- a/dialect_mysql.go +++ b/dialect_mysql.go @@ -185,6 +185,10 @@ func (m *mysql) TruncateAll(tx *Connection) error { return tx.RawQuery(qb.String()).Exec() } +func (m *mysql) afterOpen(c *Connection) error { + return nil +} + func newMySQL(deets *ConnectionDetails) (dialect, error) { cd := &mysql{ ConnectionDetails: deets, diff --git a/dialect_postgresql.go b/dialect_postgresql.go index ec9578d5..8597d5c2 100644 --- a/dialect_postgresql.go +++ b/dialect_postgresql.go @@ -199,6 +199,10 @@ func (p *postgresql) TruncateAll(tx *Connection) error { return tx.RawQuery(fmt.Sprintf(pgTruncate, tx.MigrationTableName())).Exec() } +func (p *postgresql) afterOpen(c *Connection) error { + return nil +} + func newPostgreSQL(deets *ConnectionDetails) (dialect, error) { cd := &postgresql{ ConnectionDetails: deets, diff --git a/dialect_sqlite.go b/dialect_sqlite.go index d3e378bc..b6efd5b4 100644 --- a/dialect_sqlite.go +++ b/dialect_sqlite.go @@ -211,6 +211,10 @@ func (m *sqlite) TruncateAll(tx *Connection) error { return tx.RawQuery(strings.Join(stmts, "; ")).Exec() } +func (m *sqlite) afterOpen(c *Connection) error { + return nil +} + func newSQLite(deets *ConnectionDetails) (dialect, error) { deets.URL = fmt.Sprintf("sqlite3://%s", deets.Database) cd := &sqlite{