@@ -7,31 +7,41 @@ import (
77 "context"
88 "database/sql"
99 "fmt"
10+ "io/fs"
1011 "sync"
12+ "time"
1113
1214 "github.com/golang-migrate/migrate/v4"
1315 "github.com/golang-migrate/migrate/v4/database"
1416 migmysql "github.com/golang-migrate/migrate/v4/database/mysql"
1517 "github.com/golang-migrate/migrate/v4/source"
18+ "github.com/golang-migrate/migrate/v4/source/iofs"
1619
1720 "github.com/moov-io/base/log"
1821)
1922
2023var migrationMutex sync.Mutex
2124
22- func RunMigrations (logger log.Logger , config DatabaseConfig ) error {
25+ func RunMigrations (logger log.Logger , config DatabaseConfig , opts ... MigrateOption ) error {
2326 logger .Info ().Log ("Running Migrations" )
2427
25- source , driver , err := GetDriver (logger , config )
28+ // apply all of our optional arguments
29+ o := & migrateOptions {}
30+ for _ , opt := range opts {
31+ if err := opt (o ); err != nil {
32+ return err
33+ }
34+ }
35+
36+ source , driver , err := getDriver (logger , config , o )
2637 if err != nil {
2738 return err
2839 }
29-
3040 defer driver .Close ()
3141
3242 migrationMutex .Lock ()
3343 m , err := migrate .NewWithInstance (
34- "filtering-pkger" ,
44+ source . name ,
3545 source ,
3646 config .DatabaseName ,
3747 driver ,
@@ -40,6 +50,10 @@ func RunMigrations(logger log.Logger, config DatabaseConfig) error {
4050 return logger .Fatal ().LogErrorf ("Error running migration: %w" , err ).Err ()
4151 }
4252
53+ if o .timeout != nil {
54+ m .LockTimeout = * o .timeout
55+ }
56+
4357 err = m .Up ()
4458 migrationMutex .Unlock ()
4559
@@ -56,41 +70,63 @@ func RunMigrations(logger log.Logger, config DatabaseConfig) error {
5670 return nil
5771}
5872
73+ // Deprecated: Here to not break compatibility since it was once public.
5974func GetDriver (logger log.Logger , config DatabaseConfig ) (source.Driver , database.Driver , error ) {
75+ return getDriver (logger , config , & migrateOptions {})
76+ }
77+
78+ func getDriver (logger log.Logger , config DatabaseConfig , opts * migrateOptions ) (* SourceDriver , database.Driver , error ) {
79+ var err error
80+
6081 if config .MySQL != nil {
61- src , err := NewPkgerSource ("mysql" , true )
62- if err != nil {
63- return nil , nil , err
82+ if opts .source == nil {
83+ src , err := NewPkgerSource ("mysql" , true )
84+ if err != nil {
85+ return nil , nil , err
86+ }
87+ opts .source = & SourceDriver {
88+ name : "pkger-mysql" ,
89+ Driver : src ,
90+ }
6491 }
6592
66- db , err := New ( context . Background (), logger , config )
67- if err != nil {
68- return nil , nil , err
69- }
70- defer db . Close ()
93+ if opts . driver == nil {
94+ db , err := New ( context . Background (), logger , config )
95+ if err != nil {
96+ return nil , nil , err
97+ }
7198
72- drv , err := MySQLDriver (db )
73- if err != nil {
74- return nil , nil , err
99+ opts .driver , err = MySQLDriver (db )
100+ if err != nil {
101+ return nil , nil , err
102+ }
75103 }
76104
77- return src , drv , nil
78-
79105 } else if config .Spanner != nil {
80- src , err := NewPkgerSource ("spanner" , false )
81- if err != nil {
82- return nil , nil , err
106+ if opts .source == nil {
107+ src , err := NewPkgerSource ("spanner" , false )
108+ if err != nil {
109+ return nil , nil , err
110+ }
111+ opts .source = & SourceDriver {
112+ name : "pkger-spanner" ,
113+ Driver : src ,
114+ }
83115 }
84116
85- drv , err := SpannerDriver (config )
86- if err != nil {
87- return nil , nil , err
117+ if opts .driver == nil {
118+ opts .driver , err = SpannerDriver (config )
119+ if err != nil {
120+ return nil , nil , err
121+ }
88122 }
123+ }
89124
90- return src , drv , nil
125+ if opts .source == nil || opts .driver == nil {
126+ return nil , nil , fmt .Errorf ("database config not defined" )
91127 }
92128
93- return nil , nil , fmt . Errorf ( "database config not defined" )
129+ return opts . source , opts . driver , nil
94130}
95131
96132func MySQLDriver (db * sql.DB ) (database.Driver , error ) {
@@ -100,3 +136,38 @@ func MySQLDriver(db *sql.DB) (database.Driver, error) {
100136func SpannerDriver (config DatabaseConfig ) (database.Driver , error ) {
101137 return SpannerMigrationDriver (* config .Spanner , config .DatabaseName )
102138}
139+
140+ type MigrateOption func (o * migrateOptions ) error
141+
142+ type SourceDriver struct {
143+ name string
144+ source.Driver
145+ }
146+
147+ type migrateOptions struct {
148+ source * SourceDriver
149+ driver database.Driver
150+
151+ timeout * time.Duration
152+ }
153+
154+ func WithEmbeddedMigrations (f fs.FS ) MigrateOption {
155+ return func (o * migrateOptions ) error {
156+ src , err := iofs .New (f , "migrations" )
157+ if err != nil {
158+ return err
159+ }
160+ o .source = & SourceDriver {
161+ name : "embedded" ,
162+ Driver : src ,
163+ }
164+ return nil
165+ }
166+ }
167+
168+ func WithTimeout (dur time.Duration ) MigrateOption {
169+ return func (o * migrateOptions ) error {
170+ o .timeout = & dur
171+ return nil
172+ }
173+ }
0 commit comments