Skip to content

make sql statements customizable via options #32

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
64 changes: 47 additions & 17 deletions migrator.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,14 +12,22 @@ const defaultTableName = "migrations"

// Migrator is the migrator implementation
type Migrator struct {
tableName string
logger Logger
migrations []interface{}
tableName string
logger Logger
migrations []interface{}
createSQLFormat string
insertSQLFormat string
}

// Option sets options such migrations or table name.
type Option func(*Migrator)

type insertArgs struct {
sql string
id int
version string
}

// TableName creates an option to allow overriding the default table name
func TableName(tableName string) Option {
return func(m *Migrator) {
Expand Down Expand Up @@ -47,6 +55,20 @@ func WithLogger(logger Logger) Option {
}
}

// WithCreateSQLFormat creates an option to allow overriding the create SQL script format
func WithCreateSQLFormat(createSQL string) Option {
return func(m *Migrator) {
m.createSQLFormat = createSQL
}
}

// WithInsertSQLFormat creates an option to allow overriding the insert SQL script format
func WithInsertSQLFormat(insertSQL string) Option {
return func(m *Migrator) {
m.insertSQLFormat = insertSQL
}
}

// Migrations creates an option with provided migrations
func Migrations(migrations ...interface{}) Option {
return func(m *Migrator) {
Expand All @@ -59,6 +81,14 @@ func New(opts ...Option) (*Migrator, error) {
m := &Migrator{
logger: log.New(os.Stdout, "migrator: ", 0),
tableName: defaultTableName,
createSQLFormat: `
CREATE TABLE IF NOT EXISTS %s (
id INT8 NOT NULL,
version VARCHAR(255) NOT NULL,
PRIMARY KEY (id)
);
`,
insertSQLFormat: "INSERT INTO %s (id, version) VALUES (?, ?)",
}
for _, opt := range opts {
opt(m)
Expand All @@ -83,13 +113,7 @@ func New(opts ...Option) (*Migrator, error) {
// Migrate applies all available migrations
func (m *Migrator) Migrate(db *sql.DB) error {
// create migrations table if doesn't exist
_, err := db.Exec(fmt.Sprintf(`
CREATE TABLE IF NOT EXISTS %s (
id INT8 NOT NULL,
version VARCHAR(255) NOT NULL,
PRIMARY KEY (id)
);
`, m.tableName))
_, err := db.Exec(fmt.Sprintf(m.createSQLFormat, m.tableName))
if err != nil {
return err
}
Expand All @@ -106,14 +130,17 @@ func (m *Migrator) Migrate(db *sql.DB) error {

// plan migrations
for idx, migration := range m.migrations[count:len(m.migrations)] {
insertVersion := fmt.Sprintf("INSERT INTO %s (id, version) VALUES (%d, '%s')", m.tableName, idx+count, migration.(fmt.Stringer).String())
sqlStmt := fmt.Sprintf(m.insertSQLFormat, m.tableName)
insertID := idx + count
insertVersion := migration.(fmt.Stringer).String()

switch mig := migration.(type) {
case *Migration:
if err := migrate(db, m.logger, insertVersion, mig); err != nil {
if err := migrate(db, m.logger, insertArgs{sql: sqlStmt, id: insertID, version: insertVersion,}, mig); err != nil {
return fmt.Errorf("migrator: error while running migrations: %v", err)
}
case *MigrationNoTx:
if err := migrateNoTx(db, m.logger, insertVersion, mig); err != nil {
if err := migrateNoTx(db, m.logger, insertArgs{sql: sqlStmt, id: insertID, version: insertVersion,}, mig); err != nil {
return fmt.Errorf("migrator: error while running migrations: %v", err)
}
}
Expand Down Expand Up @@ -173,7 +200,7 @@ func (m *MigrationNoTx) String() string {
return m.Name
}

func migrate(db *sql.DB, logger Logger, insertVersion string, migration *Migration) error {
func migrate(db *sql.DB, logger Logger, args insertArgs, migration *Migration) error {
tx, err := db.Begin()
if err != nil {
return err
Expand All @@ -191,20 +218,23 @@ func migrate(db *sql.DB, logger Logger, insertVersion string, migration *Migrati
if err = migration.Func(tx); err != nil {
return fmt.Errorf("error executing golang migration: %s", err)
}
if _, err = tx.Exec(insertVersion); err != nil {

if _, err = tx.Exec(args.sql, args.id, args.version); err != nil {
return fmt.Errorf("error updating migration versions: %s", err)
}
logger.Printf("applied migration named '%s'", migration.Name)

return err
}

func migrateNoTx(db *sql.DB, logger Logger, insertVersion string, migration *MigrationNoTx) error {
func migrateNoTx(db *sql.DB, logger Logger, args insertArgs, migration *MigrationNoTx) error {
logger.Printf("applying no tx migration named '%s'...", migration.Name)

if err := migration.Func(db); err != nil {
return fmt.Errorf("error executing golang migration: %s", err)
}
if _, err := db.Exec(insertVersion); err != nil {

if _, err := db.Exec(args.sql, args.id, args.version); err != nil {
return fmt.Errorf("error updating migration versions: %s", err)
}
logger.Printf("applied no tx migration named '%s'", migration.Name)
Expand Down