diff --git a/cmd/cli/handler.go b/cmd/cli/handler.go index ce650479292..da2d3746452 100644 --- a/cmd/cli/handler.go +++ b/cmd/cli/handler.go @@ -16,7 +16,7 @@ type Handler struct { func NewHandler(slOpts []servicelocatorx.Option, dOpts []driver.OptionsModifier, cOpts []configx.OptionModifier) *Handler { return &Handler{ - Migration: newMigrateHandler(), + Migration: newMigrateHandler(slOpts, dOpts, cOpts), Janitor: NewJanitorHandler(slOpts, dOpts, cOpts), } } diff --git a/cmd/cli/handler_migrate.go b/cmd/cli/handler_migrate.go index 0172584028f..a4c2dd4885d 100644 --- a/cmd/cli/handler_migrate.go +++ b/cmd/cli/handler_migrate.go @@ -34,10 +34,18 @@ import ( "github.com/ory/x/flagx" ) -type MigrateHandler struct{} +type MigrateHandler struct { + slOpts []servicelocatorx.Option + dOpts []driver.OptionsModifier + cOpts []configx.OptionModifier +} -func newMigrateHandler() *MigrateHandler { - return &MigrateHandler{} +func newMigrateHandler(slOpts []servicelocatorx.Option, dOpts []driver.OptionsModifier, cOpts []configx.OptionModifier) *MigrateHandler { + return &MigrateHandler{ + slOpts: slOpts, + dOpts: dOpts, + cOpts: cOpts, + } } const ( @@ -262,21 +270,21 @@ func (h *MigrateHandler) MigrateGen(cmd *cobra.Command, args []string) { os.Exit(0) } -func makePersister(cmd *cobra.Command, args []string) (p persistence.Persister, err error) { +func (h *MigrateHandler) makePersister(cmd *cobra.Command, args []string) (p persistence.Persister, err error) { var d driver.Registry if flagx.MustGetBool(cmd, "read-from-env") { d, err = driver.New( cmd.Context(), servicelocatorx.NewOptions(), - []driver.OptionsModifier{ + append([]driver.OptionsModifier{ driver.WithOptions( configx.SkipValidation(), configx.WithFlags(cmd.Flags())), driver.DisableValidation(), driver.DisablePreloading(), driver.SkipNetworkInit(), - }) + }, h.dOpts...)) if err != nil { return nil, err } @@ -292,7 +300,7 @@ func makePersister(cmd *cobra.Command, args []string) (p persistence.Persister, d, err = driver.New( cmd.Context(), servicelocatorx.NewOptions(), - []driver.OptionsModifier{ + append([]driver.OptionsModifier{ driver.WithOptions( configx.WithFlags(cmd.Flags()), configx.SkipValidation(), @@ -301,7 +309,7 @@ func makePersister(cmd *cobra.Command, args []string) (p persistence.Persister, driver.DisableValidation(), driver.DisablePreloading(), driver.SkipNetworkInit(), - }) + }, h.dOpts...)) if err != nil { return nil, err } @@ -310,7 +318,7 @@ func makePersister(cmd *cobra.Command, args []string) (p persistence.Persister, } func (h *MigrateHandler) MigrateSQL(cmd *cobra.Command, args []string) (err error) { - p, err := makePersister(cmd, args) + p, err := h.makePersister(cmd, args) if err != nil { return err } @@ -360,7 +368,7 @@ func (h *MigrateHandler) MigrateSQL(cmd *cobra.Command, args []string) (err erro } func (h *MigrateHandler) MigrateStatus(cmd *cobra.Command, args []string) error { - p, err := makePersister(cmd, args) + p, err := h.makePersister(cmd, args) if err != nil { return err } diff --git a/cmd/migrate_status.go b/cmd/migrate_status.go index 397f5e86e48..940728b79b6 100644 --- a/cmd/migrate_status.go +++ b/cmd/migrate_status.go @@ -4,6 +4,7 @@ package cmd import ( + "github.com/ory/x/cmdx" "github.com/ory/x/configx" "github.com/ory/x/servicelocatorx" @@ -20,6 +21,7 @@ func NewMigrateStatusCmd(slOpts []servicelocatorx.Option, dOpts []driver.Options RunE: cli.NewHandler(slOpts, dOpts, cOpts).Migration.MigrateStatus, } + cmdx.RegisterFormatFlags(cmd.PersistentFlags()) cmd.Flags().BoolP("read-from-env", "e", false, "If set, reads the database connection string from the environment variable DSN or config file key dsn.") cmd.Flags().Bool("block", false, "Block until all migrations have been applied") diff --git a/driver/factory.go b/driver/factory.go index 2e5fe949c29..4b206ac71c7 100644 --- a/driver/factory.go +++ b/driver/factory.go @@ -5,6 +5,7 @@ package driver import ( "context" + "io/fs" "github.com/ory/hydra/v2/driver/config" "github.com/ory/x/configx" @@ -22,7 +23,10 @@ type ( // The first default refers to determining the NID at startup; the second default referes to the fact that the Contextualizer may dynamically change the NID. skipNetworkInit bool tracerWrapper TracerWrapper + extraMigrations []fs.FS } + OptionsModifier func(*options) + TracerWrapper func(*otelx.Tracer) *otelx.Tracer ) @@ -34,14 +38,12 @@ func newOptions() *options { } } -func WithConfig(config *config.DefaultProvider) func(o *options) { +func WithConfig(config *config.DefaultProvider) OptionsModifier { return func(o *options) { o.config = config } } -type OptionsModifier func(*options) - func WithOptions(opts ...configx.OptionModifier) OptionsModifier { return func(o *options) { o.opts = append(o.opts, opts...) @@ -77,6 +79,13 @@ func WithTracerWrapper(wrapper TracerWrapper) OptionsModifier { } } +// WithExtraMigrations specifies additional database migration. +func WithExtraMigrations(m ...fs.FS) OptionsModifier { + return func(o *options) { + o.extraMigrations = append(o.extraMigrations, m...) + } +} + func New(ctx context.Context, sl *servicelocatorx.Options, opts []OptionsModifier) (Registry, error) { o := newOptions() for _, f := range opts { @@ -115,7 +124,7 @@ func New(ctx context.Context, sl *servicelocatorx.Options, opts []OptionsModifie r.WithTracerWrapper(o.tracerWrapper) } - if err = r.Init(ctx, o.skipNetworkInit, false, ctxter); err != nil { + if err = r.Init(ctx, o.skipNetworkInit, false, ctxter, o.extraMigrations); err != nil { l.WithError(err).Error("Unable to initialize service registry.") return nil, err } diff --git a/driver/registry.go b/driver/registry.go index c75213e52c7..4c956c4cd48 100644 --- a/driver/registry.go +++ b/driver/registry.go @@ -5,6 +5,7 @@ package driver import ( "context" + "io/fs" "net/http" "go.opentelemetry.io/otel/trace" @@ -44,7 +45,7 @@ import ( type Registry interface { dbal.Driver - Init(ctx context.Context, skipNetworkInit bool, migrate bool, ctxer contextx.Contextualizer) error + Init(ctx context.Context, skipNetworkInit bool, migrate bool, ctxer contextx.Contextualizer, extraMigrations []fs.FS) error WithBuildInfo(v, h, d string) Registry WithConfig(c *config.DefaultProvider) Registry @@ -89,7 +90,7 @@ func NewRegistryFromDSN(ctx context.Context, c *config.DefaultProvider, l *logru if err != nil { return nil, err } - if err := registry.Init(ctx, skipNetworkInit, migrate, ctxer); err != nil { + if err := registry.Init(ctx, skipNetworkInit, migrate, ctxer, nil); err != nil { return nil, err } return registry, nil diff --git a/driver/registry_base_test.go b/driver/registry_base_test.go index 4e0f80ef859..4dedab5dead 100644 --- a/driver/registry_base_test.go +++ b/driver/registry_base_test.go @@ -67,7 +67,7 @@ func TestRegistryBase_newKeyStrategy_handlesNetworkError(t *testing.T) { r := registry.(*RegistrySQL) r.initialPing = failedPing(errors.New("snizzles")) - _ = r.Init(context.Background(), true, false, &contextx.TestContextualizer{}) + _ = r.Init(context.Background(), true, false, &contextx.TestContextualizer{}, nil) registryBase := RegistryBase{r: r, l: l} registryBase.WithConfig(c) diff --git a/driver/registry_sql.go b/driver/registry_sql.go index 7660a884c90..361d1aa154d 100644 --- a/driver/registry_sql.go +++ b/driver/registry_sql.go @@ -5,6 +5,7 @@ package driver import ( "context" + "io/fs" "strings" "time" @@ -64,7 +65,11 @@ func NewRegistrySQL() *RegistrySQL { } func (m *RegistrySQL) Init( - ctx context.Context, skipNetworkInit bool, migrate bool, ctxer contextx.Contextualizer, + ctx context.Context, + skipNetworkInit bool, + migrate bool, + ctxer contextx.Contextualizer, + extraMigrations []fs.FS, ) error { if m.persister == nil { m.WithContextualizer(ctxer) @@ -100,7 +105,7 @@ func (m *RegistrySQL) Init( return errorsx.WithStack(err) } - p, err := sql.NewPersister(ctx, c, m, m.Config(), m.l) + p, err := sql.NewPersister(ctx, c, m, m.Config(), extraMigrations) if err != nil { return err } diff --git a/driver/registry_sql_test.go b/driver/registry_sql_test.go index ffbc071f1b5..cd126a3f711 100644 --- a/driver/registry_sql_test.go +++ b/driver/registry_sql_test.go @@ -31,7 +31,7 @@ func TestDefaultKeyManager_HsmDisabled(t *testing.T) { reg, err := NewRegistryWithoutInit(c, l) r := reg.(*RegistrySQL) r.initialPing = sussessfulPing() - if err := r.Init(context.Background(), true, false, &contextx.Default{}); err != nil { + if err := r.Init(context.Background(), true, false, &contextx.Default{}, nil); err != nil { t.Fatalf("unable to init registry: %s", err) } assert.NoError(t, err) diff --git a/go.mod b/go.mod index baa5d57cd2b..37b051d1756 100644 --- a/go.mod +++ b/go.mod @@ -170,6 +170,7 @@ require ( github.com/knadh/koanf/v2 v2.0.1 // indirect github.com/kr/pretty v0.3.1 // indirect github.com/kr/text v0.2.0 // indirect + github.com/laher/mergefs v0.1.1 github.com/lib/pq v1.10.7 // indirect github.com/magiconair/properties v1.8.7 // indirect github.com/mailru/easyjson v0.7.7 // indirect diff --git a/go.sum b/go.sum index 8dc45a304d3..b8bc9b59b96 100644 --- a/go.sum +++ b/go.sum @@ -110,6 +110,7 @@ github.com/cockroachdb/cockroach-go/v2 v2.2.16/go.mod h1:xZ2VHjUEb/cySv0scXBx7Ys github.com/containerd/console v1.0.3/go.mod h1:7LqA/THxQ86k76b8c/EMSiaJ3h1eZkMkXar0TQ1gf3U= github.com/containerd/continuity v0.3.0 h1:nisirsYROK15TAMVukJOUyGJjz4BNQJBVsNvAXZJ/eg= github.com/containerd/continuity v0.3.0/go.mod h1:wJEAIwKOm/pBZuBd0JmeTvnLquTB1Ag8espWhkykbPM= +github.com/coreos/go-semver v0.3.0/go.mod h1:nnelYz7RCh+5ahJtPPxZlU+153eP4D4r3EedlOD2RNk= github.com/coreos/go-systemd v0.0.0-20190321100706-95778dfbb74e/go.mod h1:F5haX7vjVVG0kc13fIWeqUViNPyEJxv/OmvnBo0Yme4= github.com/coreos/go-systemd v0.0.0-20190719114852-fd7a80b32e1f/go.mod h1:F5haX7vjVVG0kc13fIWeqUViNPyEJxv/OmvnBo0Yme4= github.com/coreos/go-systemd/v22 v22.3.2/go.mod h1:Y58oyj3AT4RCenI/lSvhwexgC+NSVTIJ3seZv2GcEnc= @@ -549,6 +550,8 @@ github.com/kr/pty v1.1.8/go.mod h1:O1sed60cT9XZ5uDucP5qwvh+TE3NnUj51EiZO/lmSfw= github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= +github.com/laher/mergefs v0.1.1 h1:nV2bTS57vrmbMxeR6uvJpI8LyGl3QHj4bLBZO3aUV58= +github.com/laher/mergefs v0.1.1/go.mod h1:FSY1hYy94on4Tz60waRMGdO1awwS23BacqJlqf9lJ9Q= github.com/leodido/go-urn v1.2.0 h1:hpXL4XnriNwQ/ABnpepYM/1vCLWNDfUNts8dX3xTG6Y= github.com/leodido/go-urn v1.2.0/go.mod h1:+8+nEpDfqqsY+g338gtMEUOtuK+4dEMhiQEgxpxOKII= github.com/lib/pq v1.0.0/go.mod h1:5WUZQaWbwv1U+lTReE5YruASi9Al49XbQIvNi/34Woo= @@ -570,6 +573,7 @@ github.com/mailru/easyjson v0.7.7/go.mod h1:xzfreul335JAWq5oZzymOObrkdz5UnU4kGfJ github.com/markbates/oncer v0.0.0-20181203154359-bf2de49a0be2/go.mod h1:Ld9puTsIW75CHf65OeIOkyKbteujpZVXDpWK6YGZbxE= github.com/markbates/pkger v0.17.1 h1:/MKEtWqtc0mZvu9OinB9UzVN9iYCwLWuyUv4Bw+PCno= github.com/markbates/safe v1.0.1/go.mod h1:nAqgmRi7cY2nqMc92/bSEeQA+R4OheNU2T1kNSCBdG0= +github.com/matryer/is v1.4.0/go.mod h1:8I/i5uYgLzgsgEloJE1U6xx5HkBQpAZvepWuujKwMRU= github.com/mattn/go-colorable v0.1.1/go.mod h1:FuOcm+DKB9mbwrcAfNl7/TZVBZ6rcnceauSikq3lYCQ= github.com/mattn/go-colorable v0.1.6/go.mod h1:u6P/XSegPjTcexA+o6vUJrdnUu04hMope9wVRipJSqc= github.com/mattn/go-colorable v0.1.8/go.mod h1:u6P/XSegPjTcexA+o6vUJrdnUu04hMope9wVRipJSqc= @@ -600,6 +604,7 @@ github.com/mikefarah/yq/v4 v4.16.1/go.mod h1:mfI3lycn5DjU6N4kfpiR4S7ylu0xZj9XgKS github.com/mitchellh/copystructure v1.0.0/go.mod h1:SNtv71yrdKgLRyLFxmLdkAbkKEFWgYaq1OVrnRcwhnw= github.com/mitchellh/copystructure v1.2.0 h1:vpKXTN4ewci03Vljg/q9QvCGUDttBOGBIa15WveJJGw= github.com/mitchellh/copystructure v1.2.0/go.mod h1:qLl+cE2AmVv+CoeAwDPye/v+N2HKCj9FbZEVFJRxO9s= +github.com/mitchellh/go-homedir v1.1.0/go.mod h1:SfyaCUpYCn1Vlf4IUYiD9fPX4A5wJrkLzIz1N1q0pr0= github.com/mitchellh/mapstructure v1.3.3/go.mod h1:bFUtVrKA4DC2yAKiSyO/QUcy7e+RRV2QTWOzhPopBRo= github.com/mitchellh/mapstructure v1.4.1/go.mod h1:bFUtVrKA4DC2yAKiSyO/QUcy7e+RRV2QTWOzhPopBRo= github.com/mitchellh/mapstructure v1.4.3/go.mod h1:bFUtVrKA4DC2yAKiSyO/QUcy7e+RRV2QTWOzhPopBRo= @@ -667,6 +672,7 @@ github.com/pelletier/go-toml v1.9.5/go.mod h1:u1nR/EPcESfeI/szUZKdtJ0xRNbUoANCko github.com/pelletier/go-toml/v2 v2.0.8 h1:0ctb6s9mE31h0/lhu+J6OPmVeDxJn+kYnJc2jZR9tGQ= github.com/pelletier/go-toml/v2 v2.0.8/go.mod h1:vuYfssBdrU2XDZ9bYydBu6t+6a6PYNcZljzZR9VXg+4= github.com/phayes/freeport v0.0.0-20180830031419-95f893ade6f2 h1:JhzVVoYvbOACxoUmOs6V/G4D5nPVUW73rKvXxP4XUJc= +github.com/phayes/freeport v0.0.0-20180830031419-95f893ade6f2/go.mod h1:iIss55rKnNBTvrwdmkUpLnDpZoAHvWaiq5+iMmen4AE= github.com/pkg/diff v0.0.0-20210226163009-20ebb0f2a09e/go.mod h1:pJLUxLENpZxwdsKMEsNbx1VGcRFpLqf3715MtcvvzbA= github.com/pkg/errors v0.8.0/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= @@ -745,6 +751,7 @@ github.com/sirupsen/logrus v1.6.0/go.mod h1:7uNnSEd1DgxDLC74fIahvMZmmYsHGZGEOFrf github.com/sirupsen/logrus v1.8.1/go.mod h1:yWOB1SBYBC5VeMP7gHvWumXLIWorT60ONWic61uBYv0= github.com/sirupsen/logrus v1.9.0 h1:trlNQbNUG3OdDrDil03MCb1H2o9nJ1x4/5LYw7byDE0= github.com/sirupsen/logrus v1.9.0/go.mod h1:naHLuLoDiP4jHNo9R0sCBMtWGeIprob74mVsIT4qYEQ= +github.com/smartystreets/goconvey v1.6.4/go.mod h1:syvi0/a8iFYH4r/RixwvyeAJjdLS9QV7WQ/tjFTllLA= github.com/sourcegraph/annotate v0.0.0-20160123013949-f4cad6c6324d h1:yKm7XZV6j9Ev6lojP2XaIshpT4ymkqhMeSghO5Ps00E= github.com/sourcegraph/annotate v0.0.0-20160123013949-f4cad6c6324d/go.mod h1:UdhH50NIW0fCiwBSr0co2m7BnFLdv4fQTgdqdJTHFeE= github.com/sourcegraph/syntaxhighlight v0.0.0-20170531221838-bd320f5d308e h1:qpG93cPwA5f7s/ZPBJnGOYQNK/vKsaDaseuKT5Asee8= @@ -1300,6 +1307,7 @@ google.golang.org/grpc v1.42.0/go.mod h1:k+4IHHFw41K8+bbowsex27ge2rCb65oeWqe4jJ5 google.golang.org/grpc v1.55.0 h1:3Oj82/tFSCeUrRTg/5E/7d/W5A1tj6Ky1ABAuZuv5ag= google.golang.org/grpc v1.55.0/go.mod h1:iYEXKGkEBhg1PjZQvoYEVPTDkHo1/bjTnfwTeGONTY8= google.golang.org/grpc/examples v0.0.0-20210304020650-930c79186c99 h1:qA8rMbz1wQ4DOFfM2ouD29DG9aHWBm6ZOy9BGxiUMmY= +google.golang.org/grpc/examples v0.0.0-20210304020650-930c79186c99/go.mod h1:Ly7ZA/ARzg8fnPU9TyZIxoz33sEUuWX7txiqs8lPTgE= google.golang.org/protobuf v0.0.0-20200109180630-ec00e32a8dfd/go.mod h1:DFci5gLYBciE7Vtevhsrf46CRTquxDuWsQurQQe4oz8= google.golang.org/protobuf v0.0.0-20200221191635-4d8936d0db64/go.mod h1:kwYJMbMJ01Woi6D6+Kah6886xMZcty6N08ah7+eCXa0= google.golang.org/protobuf v0.0.0-20200228230310-ab0ca4ff8a60/go.mod h1:cfTl7dwQJ+fmap5saPgwCLgHXTUD7jkjRqWcaiX5VyM= diff --git a/hsm/manager_hsm_test.go b/hsm/manager_hsm_test.go index cf629bd62c9..e7bc145180e 100644 --- a/hsm/manager_hsm_test.go +++ b/hsm/manager_hsm_test.go @@ -52,7 +52,7 @@ func TestDefaultKeyManager_HSMEnabled(t *testing.T) { reg.WithLogger(l) reg.WithConfig(c) reg.WithHsmContext(mockHsmContext) - err := reg.Init(context.Background(), false, true, &contextx.TestContextualizer{}) + err := reg.Init(context.Background(), false, true, &contextx.TestContextualizer{}, nil) assert.NoError(t, err) assert.IsType(t, &jwk.ManagerStrategy{}, reg.KeyManager()) assert.IsType(t, &sql.Persister{}, reg.SoftwareKeyManager()) diff --git a/persistence/sql/persister.go b/persistence/sql/persister.go index e454f527fac..c9960a9281e 100644 --- a/persistence/sql/persister.go +++ b/persistence/sql/persister.go @@ -6,10 +6,12 @@ package sql import ( "context" "database/sql" + "io/fs" "reflect" "github.com/gobuffalo/pop/v6" "github.com/gofrs/uuid" + "github.com/laher/mergefs" "github.com/pkg/errors" @@ -104,8 +106,8 @@ func (p *Persister) Rollback(ctx context.Context) (err error) { return errorsx.WithStack(tx.TX.Rollback()) } -func NewPersister(ctx context.Context, c *pop.Connection, r Dependencies, config *config.DefaultProvider, l *logrusx.Logger) (*Persister, error) { - mb, err := popx.NewMigrationBox(migrations, popx.NewMigrator(c, r.Logger(), r.Tracer(ctx), 0)) +func NewPersister(ctx context.Context, c *pop.Connection, r Dependencies, config *config.DefaultProvider, extraMigrations []fs.FS) (*Persister, error) { + mb, err := popx.NewMigrationBox(mergefs.Merge(append([]fs.FS{migrations}, extraMigrations...)...), popx.NewMigrator(c, r.Logger(), r.Tracer(ctx), 0)) if err != nil { return nil, errorsx.WithStack(err) } @@ -115,7 +117,7 @@ func NewPersister(ctx context.Context, c *pop.Connection, r Dependencies, config mb: mb, r: r, config: config, - l: l, + l: r.Logger(), p: networkx.NewManager(c, r.Logger(), r.Tracer(ctx)), }, nil }