Skip to content

Commit

Permalink
Improving connection (#323)
Browse files Browse the repository at this point in the history
* fixed to prevent panics by wrong usage

* added afterOpen to give dialect a chance to initial setup
  • Loading branch information
sio4 authored and stanislas-m committed Dec 2, 2018
1 parent 29bb60f commit 264e655
Show file tree
Hide file tree
Showing 7 changed files with 98 additions and 30 deletions.
15 changes: 12 additions & 3 deletions connection.go
Original file line number Diff line number Diff line change
Expand Up @@ -90,14 +90,23 @@ func (c *Connection) Open() error {
if c.Store != nil {
return nil
}
if c.Dialect == nil {
return errors.New("invalid connection instance")
}
details := c.Dialect.Details()
db, err := sqlx.Open(details.Dialect, c.Dialect.URL())
if err != nil {
return errors.Wrap(err, "could not open database connection")
}
db.SetMaxOpenConns(details.Pool)
db.SetMaxIdleConns(details.IdlePool)
if err == nil {
c.Store = &dB{db}
c.Store = &dB{db}

err = c.Dialect.afterOpen(c)
if err != nil {
c.Store = nil
}
return errors.Wrap(err, "couldn't connect to database")
return errors.Wrap(err, "could not open database connection")
}

// Close destroys an active datasource connection
Expand Down
51 changes: 51 additions & 0 deletions connection_test.go
Original file line number Diff line number Diff line change
@@ -1 +1,52 @@
package pop

import (
"testing"

"github.com/stretchr/testify/require"
)

func Test_Connection_SimpleFlow(t *testing.T) {
r := require.New(t)

cd := &ConnectionDetails{
URL: "sqlite:///foo.db",
}
c, err := NewConnection(cd)
r.NoError(err)

err = c.Open()
r.NoError(err)
err = c.Open() // open again
r.NoError(err)
err = c.Close()
r.NoError(err)
}

func Test_Connection_Open_NoDialect(t *testing.T) {
r := require.New(t)

cd := &ConnectionDetails{
URL: "sqlite:///foo.db",
}
c, err := NewConnection(cd)
r.NoError(err)

c.Dialect = nil
err = c.Open()
r.Error(err)
}

func Test_Connection_Open_BadDialect(t *testing.T) {
r := require.New(t)

cd := &ConnectionDetails{
URL: "sqlite:///foo.db",
}
c, err := NewConnection(cd)
r.NoError(err)

cd.Dialect = "unknown"
err = c.Open()
r.Error(err)
}
1 change: 1 addition & 0 deletions dialect.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ type dialect interface {
FizzTranslator() fizz.Translator
Lock(func() error) error
TruncateAll(*Connection) error
afterOpen(*Connection) error
}

func genericCreate(s store, model *Model, cols columns.Columns) error {
Expand Down
49 changes: 22 additions & 27 deletions dialect_cockroach.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,19 +32,19 @@ func init() {
var _ dialect = &cockroach{}

// ServerInfo holds informational data about connected database server.
type ServerInfo struct {
type cockroachInfo struct {
VersionString string `db:"version"`
Product string `db:"-"`
License string `db:"-"`
Version string `db:"-"`
BuildInfo string `db:"-"`
product string `db:"-"`
license string `db:"-"`
version string `db:"-"`
buildInfo string `db:"-"`
}

type cockroach struct {
translateCache map[string]string
mu sync.Mutex
ConnectionDetails *ConnectionDetails
Server ServerInfo
info cockroachInfo
}

func (p *cockroach) Name() string {
Expand Down Expand Up @@ -218,33 +218,13 @@ func (p *cockroach) LoadSchema(r io.Reader) error {
return genericLoadSchema(p.ConnectionDetails, p.MigrationURL(), r)
}

func (p *cockroach) FillServerInfo(tx *Connection) error {
if err := tx.RawQuery(`select version() AS "version"`).First(&p.Server); err != nil {
return err
}
if s := strings.Split(p.Server.VersionString, " "); len(s) > 3 {
p.Server.Product = s[0]
p.Server.License = s[1]
p.Server.Version = s[2]
p.Server.BuildInfo = s[3]
}
log(logging.Debug, "server: %v %v %v", p.Server.Product, p.Server.License, p.Server.Version)

return nil
}

func (p *cockroach) TruncateAll(tx *Connection) error {
type table struct {
TableName string `db:"table_name"`
}

// move it to `newCockroach()` if it need more
if err := p.FillServerInfo(tx); err != nil {
return err
}

tableQuery := "select table_name from information_schema.tables where table_schema = 'public' and table_type = 'BASE TABLE' and table_catalog = ?"
if strings.HasPrefix(p.Server.Version, "v1") {
if strings.HasPrefix(p.info.version, "v1") {
tableQuery = "select table_name from information_schema.tables where table_schema = ?"
}

Expand Down Expand Up @@ -272,6 +252,21 @@ func (p *cockroach) TruncateAll(tx *Connection) error {
// return tx3.RawQuery(fmt.Sprintf("truncate %s cascade;", strings.Join(tableNames, ", "))).Exec()
}

func (p *cockroach) afterOpen(c *Connection) error {
if err := c.RawQuery(`select version() AS "version"`).First(&p.info); err != nil {
return err
}
if s := strings.Split(p.info.VersionString, " "); len(s) > 3 {
p.info.product = s[0]
p.info.license = s[1]
p.info.version = s[2]
p.info.buildInfo = s[3]
}
log(logging.Debug, "server: %v %v %v", p.info.product, p.info.license, p.info.version)

return nil
}

func newCockroach(deets *ConnectionDetails) (dialect, error) {
deets.Dialect = "postgres"
cd := &cockroach{
Expand Down
4 changes: 4 additions & 0 deletions dialect_mysql.go
Original file line number Diff line number Diff line change
Expand Up @@ -185,6 +185,10 @@ func (m *mysql) TruncateAll(tx *Connection) error {
return tx.RawQuery(qb.String()).Exec()
}

func (m *mysql) afterOpen(c *Connection) error {
return nil
}

func newMySQL(deets *ConnectionDetails) (dialect, error) {
cd := &mysql{
ConnectionDetails: deets,
Expand Down
4 changes: 4 additions & 0 deletions dialect_postgresql.go
Original file line number Diff line number Diff line change
Expand Up @@ -199,6 +199,10 @@ func (p *postgresql) TruncateAll(tx *Connection) error {
return tx.RawQuery(fmt.Sprintf(pgTruncate, tx.MigrationTableName())).Exec()
}

func (p *postgresql) afterOpen(c *Connection) error {
return nil
}

func newPostgreSQL(deets *ConnectionDetails) (dialect, error) {
cd := &postgresql{
ConnectionDetails: deets,
Expand Down
4 changes: 4 additions & 0 deletions dialect_sqlite.go
Original file line number Diff line number Diff line change
Expand Up @@ -211,6 +211,10 @@ func (m *sqlite) TruncateAll(tx *Connection) error {
return tx.RawQuery(strings.Join(stmts, "; ")).Exec()
}

func (m *sqlite) afterOpen(c *Connection) error {
return nil
}

func newSQLite(deets *ConnectionDetails) (dialect, error) {
deets.URL = fmt.Sprintf("sqlite3://%s", deets.Database)
cd := &sqlite{
Expand Down

0 comments on commit 264e655

Please sign in to comment.