From 007f941bd0061dac1e8ba0467cc29f191fb78dd4 Mon Sep 17 00:00:00 2001 From: Brandur Date: Sat, 13 Jul 2024 11:45:20 -0700 Subject: [PATCH] Break out CLI functionality to make it modular + `--line` flag Here, break out the CLI so that most of its implementation no longer lives in a `main.go`, and rather in a subpackage that can be imported from somewhere else. This lets the majority of it be reused in another executable and augmented in various ways. An addition is a new `DriverProcurer` interface that can provide a driver for various databases: type DriverProcurer interface { ProcurePgxV5(pool *pgxpool.Pool) riverdriver.Driver[pgx.Tx] } For the main CLI, this gets a trivial implementation using `riverpgxv5`, but could potentially be reimplemented elsewhere to swap in something else: type DriverProcurer struct{} func (p *DriverProcurer) ProcurePgxV5(pool *pgxpool.Pool) riverdriver.Driver[pgx.Tx] { return riverpgxv5.New(pool) } To make this more workable, I end up going through and refactoring quite a lot of code, making it into a bit of a mini framework, and one that could potentially support additional databases in the future without having to refactor the world again. We also add a `--line` flag to bring the feature from #435 to the CLI. --- CHANGELOG.md | 1 + cmd/river/main.go | 505 +----------------- cmd/river/riverbench/river_bench.go | 33 +- cmd/river/rivercli/command.go | 154 ++++++ cmd/river/rivercli/river_cli.go | 480 +++++++++++++++++ .../river_cli_test.go} | 2 +- 6 files changed, 661 insertions(+), 514 deletions(-) create mode 100644 cmd/river/rivercli/command.go create mode 100644 cmd/river/rivercli/river_cli.go rename cmd/river/{main_test.go => rivercli/river_cli_test.go} (95%) diff --git a/CHANGELOG.md b/CHANGELOG.md index 33d8e128..cb7d6222 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -19,6 +19,7 @@ river migrate-up --database-url "$DATABASE_URL" - Fully functional driver for `database/sql` for use with packages like Bun and GORM. 
[PR #351](https://github.com/riverqueue/river/pull/351). - Queues can be added after a client is initialized using `client.Queues().Add(queueName string, queueConfig QueueConfig)`. [PR #410](https://github.com/riverqueue/river/pull/410). - Migration that adds a `line` column to the `river_migration` table so that it can support multiple migration lines. [PR #435](https://github.com/riverqueue/river/pull/435). +- `--line` flag added to the River CLI. [PR #454](https://github.com/riverqueue/river/pull/454). ### Changed diff --git a/cmd/river/main.go b/cmd/river/main.go index 8877a375..a2ed1f85 100644 --- a/cmd/river/main.go +++ b/cmd/river/main.go @@ -1,514 +1,29 @@ package main import ( - "context" - "errors" - "fmt" - "io" - "log/slog" "os" - "slices" - "strconv" - "strings" - "time" + "github.com/jackc/pgx/v5" "github.com/jackc/pgx/v5/pgxpool" - "github.com/lmittmann/tint" - "github.com/spf13/cobra" - "github.com/riverqueue/river/cmd/river/riverbench" + "github.com/riverqueue/river/cmd/river/rivercli" + "github.com/riverqueue/river/riverdriver" "github.com/riverqueue/river/riverdriver/riverpgxv5" - "github.com/riverqueue/river/rivermigrate" ) -func main() { - var rootOpts struct { - Debug bool - Verbose bool - } - - rootCmd := &cobra.Command{ - Use: "river", - Short: "Provides command line facilities for the River job queue", - Long: strings.TrimSpace(` -Provides command line facilities for the River job queue. 
- `), - Run: func(cmd *cobra.Command, args []string) { - _ = cmd.Usage() - }, - } - rootCmd.PersistentFlags().BoolVar(&rootOpts.Debug, "debug", false, "output maximum logging verbosity (debug level)") - rootCmd.PersistentFlags().BoolVarP(&rootOpts.Verbose, "verbose", "v", false, "output additional logging verbosity (info level)") - rootCmd.MarkFlagsMutuallyExclusive("debug", "verbose") - - ctx := context.Background() - - execHandlingError := func(f func() (bool, error)) { - ok, err := f() - if err != nil { - fmt.Fprintf(os.Stderr, "failed: %s\n", err) - } - if err != nil || !ok { - os.Exit(1) - } - } - - makeLogger := func() *slog.Logger { - switch { - case rootOpts.Debug: - return slog.New(tint.NewHandler(os.Stdout, &tint.Options{Level: slog.LevelDebug})) - case rootOpts.Verbose: - return slog.New(tint.NewHandler(os.Stdout, nil)) - default: - return slog.New(tint.NewHandler(os.Stdout, &tint.Options{Level: slog.LevelWarn})) - } - } - - mustMarkFlagRequired := func(cmd *cobra.Command, name string) { - // We just panic here because this will never happen outside of an error - // in development. - if err := cmd.MarkFlagRequired(name); err != nil { - panic(err) - } - } - - // bench - { - var opts benchOpts - - cmd := &cobra.Command{ - Use: "bench", - Short: "Run River benchmark", - Long: strings.TrimSpace(` -Run a River benchmark which inserts and works jobs continually, giving a rough -idea of jobs per second and time to work a single job. - -By default, the benchmark will continuously insert and work jobs in perpetuity -until interrupted by SIGINT (Ctrl^C). It can alternatively take a maximum run -duration with --duration, which takes a Go-style duration string like 1m. -Lastly, it can take --num-total-jobs, which inserts the given number of jobs -before starting the client, and works until all jobs are finished. - -The database in --database-url will have its jobs table truncated, so make sure -to use a development database only. 
- `), - Run: func(cmd *cobra.Command, args []string) { - execHandlingError(func() (bool, error) { return bench(ctx, makeLogger(), os.Stdout, &opts) }) - }, - } - cmd.Flags().StringVar(&opts.DatabaseURL, "database-url", "", "URL of the database to benchmark (should look like `postgres://...`") - cmd.Flags().DurationVar(&opts.Duration, "duration", 0, "duration after which to stop benchmark, accepting Go-style durations like 1m, 5m30s") - cmd.Flags().IntVarP(&opts.NumTotalJobs, "num-total-jobs", "n", 0, "number of jobs to insert before starting and which are worked down until finish") - mustMarkFlagRequired(cmd, "database-url") - rootCmd.AddCommand(cmd) - } - - // migrate-down and migrate-up share a set of options, so this is a way of - // plugging in all the right flags to both so options and docstrings stay - // consistent. - addMigrateFlags := func(cmd *cobra.Command, opts *migrateOpts) { - cmd.Flags().StringVar(&opts.DatabaseURL, "database-url", "", "URL of the database to migrate (should look like `postgres://...`") - cmd.Flags().BoolVar(&opts.DryRun, "dry-run", false, "print information on migrations, but don't apply them") - cmd.Flags().IntVar(&opts.MaxSteps, "max-steps", 0, "maximum number of steps to migrate") - cmd.Flags().BoolVar(&opts.ShowSQL, "show-sql", false, "show SQL of each migration") - cmd.Flags().IntVar(&opts.TargetVersion, "target-version", 0, "target version to migrate to (final state includes this version, but none after it)") - mustMarkFlagRequired(cmd, "database-url") - } - - // migrate-down - { - var opts migrateOpts - - cmd := &cobra.Command{ - Use: "migrate-down", - Short: "Run River schema down migrations", - Long: strings.TrimSpace(` -Run down migrations to reverse the River database schema changes. - -Defaults to running a single down migration. This behavior can be changed with ---max-steps or --target-version. - -SQL being run can be output using --show-sql, and executing real database -operations can be prevented with --dry-run. 
Combine --show-sql and --dry-run to -dump prospective migrations that would be applied to stdout. - `), - Run: func(cmd *cobra.Command, args []string) { - execHandlingError(func() (bool, error) { return migrateDown(ctx, makeLogger(), os.Stdout, &opts) }) - }, - } - addMigrateFlags(cmd, &opts) - rootCmd.AddCommand(cmd) - } - - // migrate-get - { - var opts migrateGetOpts - - cmd := &cobra.Command{ - Use: "migrate-get", - Short: "Get SQL for specific River migration", - Long: strings.TrimSpace(` -Retrieve SQL for a single migration version. This command is aimed at cases -where using River's internal migration framework isn't desirable by allowing -migration SQL to be dumped for use elsewhere. - -Specify a version with --version, and one of --down or --up: - - river migrate-get --version 3 --up > river3.up.sql - river migrate-get --version 3 --down > river3.down.sql - -Can also take multiple versions by separating them with commas or passing ---version multiple times: - - river migrate-get --version 1,2,3 --up > river.up.sql - river migrate-get --version 3,2,1 --down > river.down.sql - -Or use --all to print all known migrations in either direction. 
Often used in -conjunction with --exclude-version 1 to exclude the tables for River's migration -framework, which aren't necessary if using an external framework: - - river migrate-get --all --exclude-version 1 --up > river_all.up.sql - river migrate-get --all --exclude-version 1 --down > river_all.down.sql - `), - Run: func(cmd *cobra.Command, args []string) { - execHandlingError(func() (bool, error) { return migrateGet(ctx, makeLogger(), os.Stdout, &opts) }) - }, - } - cmd.Flags().BoolVar(&opts.All, "all", false, "print all migrations; down migrations are printed in descending order") - cmd.Flags().BoolVar(&opts.Down, "down", false, "print down migration") - cmd.Flags().IntSliceVar(&opts.ExcludeVersion, "exclude-version", nil, "exclude version(s), usually version 1, containing River's migration tables") - cmd.Flags().BoolVar(&opts.Up, "up", false, "print up migration") - cmd.Flags().IntSliceVar(&opts.Version, "version", nil, "version(s) to print (can be multiple versions)") - cmd.MarkFlagsMutuallyExclusive("all", "version") - cmd.MarkFlagsOneRequired("all", "version") - cmd.MarkFlagsMutuallyExclusive("down", "up") - cmd.MarkFlagsOneRequired("down", "up") - rootCmd.AddCommand(cmd) - } - - // migrate-up - { - var opts migrateOpts - - cmd := &cobra.Command{ - Use: "migrate-up", - Short: "Run River schema up migrations", - Long: strings.TrimSpace(` -Run up migrations to raise the database schema necessary to run River. +type DriverProcurer struct{} -Defaults to running all up migrations that aren't yet run. This behavior can be -restricted with --max-steps or --target-version. - -SQL being run can be output using --show-sql, and executing real database -operations can be prevented with --dry-run. Combine --show-sql and --dry-run to -dump prospective migrations that would be applied to stdout. 
- `), - Run: func(cmd *cobra.Command, args []string) { - execHandlingError(func() (bool, error) { return migrateUp(ctx, makeLogger(), os.Stdout, &opts) }) - }, - } - addMigrateFlags(cmd, &opts) - rootCmd.AddCommand(cmd) - } - - // validate - { - var opts validateOpts - - cmd := &cobra.Command{ - Use: "validate", - Short: "Validate River schema", - Long: strings.TrimSpace(` -Validates the current River schema, exiting with a non-zero status in case there -are outstanding migrations that still need to be run. +func (p *DriverProcurer) ProcurePgxV5(pool *pgxpool.Pool) riverdriver.Driver[pgx.Tx] { + return riverpgxv5.New(pool) +} -Can be paired with river migrate-up --dry-run --show-sql to dump information on -migrations that need to be run, but without running them. - `), - Run: func(cmd *cobra.Command, args []string) { - execHandlingError(func() (bool, error) { return validate(ctx, makeLogger(), os.Stdout, &opts) }) - }, - } - cmd.Flags().StringVar(&opts.DatabaseURL, "database-url", "", "URL of the database to validate (should look like `postgres://...`") - mustMarkFlagRequired(cmd, "database-url") - rootCmd.AddCommand(cmd) - } +func main() { + cli := rivercli.NewCLI(&DriverProcurer{}) - if err := rootCmd.Execute(); err != nil { + if err := cli.BaseCommandSet().Execute(); err != nil { // Cobra will already print an error on problems like an unknown command // or missing required flag. Set an exit status of 1 on error, but don't // print it again. 
os.Exit(1) } } - -func openDBPool(ctx context.Context, databaseURL string) (*pgxpool.Pool, error) { - const ( - defaultIdleInTransactionSessionTimeout = 11 * time.Second // should be greater than statement timeout because statements count towards idle-in-transaction - defaultStatementTimeout = 10 * time.Second - ) - - pgxConfig, err := pgxpool.ParseConfig(databaseURL) - if err != nil { - return nil, fmt.Errorf("error parsing database URL: %w", err) - } - - setParamIfUnset(pgxConfig.ConnConfig.RuntimeParams, "application_name", "river CLI") - setParamIfUnset(pgxConfig.ConnConfig.RuntimeParams, "idle_in_transaction_session_timeout", strconv.Itoa(int(defaultIdleInTransactionSessionTimeout.Milliseconds()))) - setParamIfUnset(pgxConfig.ConnConfig.RuntimeParams, "statement_timeout", strconv.Itoa(int(defaultStatementTimeout.Milliseconds()))) - - dbPool, err := pgxpool.NewWithConfig(ctx, pgxConfig) - if err != nil { - return nil, fmt.Errorf("error connecting to database: %w", err) - } - - return dbPool, nil -} - -// Sets a parameter in a parameter map (aimed at a Postgres connection -// configuration map), but only if that parameter wasn't already set. 
-func setParamIfUnset(runtimeParams map[string]string, name, val string) { - if currentVal := runtimeParams[name]; currentVal != "" { - return - } - - runtimeParams[name] = val -} - -type benchOpts struct { - DatabaseURL string - Debug bool - Duration time.Duration - NumTotalJobs int - Verbose bool -} - -func (o *benchOpts) validate() error { - if o.DatabaseURL == "" { - return errors.New("database URL cannot be empty") - } - - return nil -} - -func bench(ctx context.Context, logger *slog.Logger, _ io.Writer, opts *benchOpts) (bool, error) { - if err := opts.validate(); err != nil { - return false, err - } - - dbPool, err := openDBPool(ctx, opts.DatabaseURL) - if err != nil { - return false, err - } - defer dbPool.Close() - - benchmarker := riverbench.NewBenchmarker(riverpgxv5.New(dbPool), logger, opts.Duration, opts.NumTotalJobs) - - if err := benchmarker.Run(ctx); err != nil { - return false, err - } - - return true, nil -} - -type migrateOpts struct { - DatabaseURL string - DryRun bool - ShowSQL bool - MaxSteps int - TargetVersion int -} - -func (o *migrateOpts) validate() error { - if o.DatabaseURL == "" { - return errors.New("database URL cannot be empty") - } - - return nil -} - -func migrateDown(ctx context.Context, logger *slog.Logger, out io.Writer, opts *migrateOpts) (bool, error) { - if err := opts.validate(); err != nil { - return false, err - } - - // Default to applying only one migration maximum on the down direction. 
- if opts.MaxSteps == 0 && opts.TargetVersion == 0 { - opts.MaxSteps = 1 - } - - dbPool, err := openDBPool(ctx, opts.DatabaseURL) - if err != nil { - return false, err - } - defer dbPool.Close() - - migrator := rivermigrate.New(riverpgxv5.New(dbPool), &rivermigrate.Config{Logger: logger}) - - res, err := migrator.Migrate(ctx, rivermigrate.DirectionDown, &rivermigrate.MigrateOpts{ - DryRun: opts.DryRun, - MaxSteps: opts.MaxSteps, - TargetVersion: opts.TargetVersion, - }) - if err != nil { - return false, err - } - - migratePrintResult(out, opts, res, rivermigrate.DirectionDown) - - return true, nil -} - -func migratePrintResult(out io.Writer, opts *migrateOpts, res *rivermigrate.MigrateResult, direction rivermigrate.Direction) { - if len(res.Versions) < 1 { - fmt.Fprintf(out, "no migrations to apply\n") - return - } - - for _, migrateVersion := range res.Versions { - if opts.DryRun { - fmt.Fprintf(out, "migration %03d [%s] [DRY RUN]\n", migrateVersion.Version, direction) - } else { - fmt.Fprintf(out, "applied migration %03d [%s] [%s]\n", migrateVersion.Version, direction, migrateVersion.Duration) - } - - if opts.ShowSQL { - fmt.Fprintf(out, "%s\n", strings.Repeat("-", 80)) - fmt.Fprintf(out, "%s\n", migrationComment(migrateVersion.Version, direction)) - fmt.Fprintf(out, "%s\n\n", strings.TrimSpace(migrateVersion.SQL)) - } - } - - // Only prints if more steps than available were requested. - if opts.MaxSteps > 0 && len(res.Versions) < opts.MaxSteps { - fmt.Fprintf(out, "no more migrations to apply\n") - } -} - -// An informational comment that's tagged on top of any migration's SQL to help -// attribute what it is for when it's copied elsewhere like other migration -// frameworks. 
-func migrationComment(version int, direction rivermigrate.Direction) string { - return fmt.Sprintf("-- River migration %03d [%s]", version, direction) -} - -type migrateGetOpts struct { - All bool - Down bool - ExcludeVersion []int - Up bool - Version []int -} - -func migrateGet(_ context.Context, logger *slog.Logger, out io.Writer, opts *migrateGetOpts) (bool, error) { - migrator := rivermigrate.New(riverpgxv5.New(nil), &rivermigrate.Config{Logger: logger}) - - var migrations []rivermigrate.Migration - if opts.All { - migrations = migrator.AllVersions() - if opts.Down { - slices.Reverse(migrations) - } - } else { - for _, version := range opts.Version { - migration, err := migrator.GetVersion(version) - if err != nil { - return false, err - } - - migrations = append(migrations, migration) - } - } - - var printedOne bool - - for _, migration := range migrations { - if slices.Contains(opts.ExcludeVersion, migration.Version) { - continue - } - - // print newlines between multiple versions - if printedOne { - fmt.Fprintf(out, "\n") - } - - var ( - direction rivermigrate.Direction - sql string - ) - switch { - case opts.Down: - direction = rivermigrate.DirectionDown - sql = migration.SQLDown - case opts.Up: - direction = rivermigrate.DirectionUp - sql = migration.SQLUp - } - - printedOne = true - fmt.Fprintf(out, "%s\n", migrationComment(migration.Version, direction)) - fmt.Fprintf(out, "%s\n", strings.TrimSpace(sql)) - } - - return true, nil -} - -func migrateUp(ctx context.Context, logger *slog.Logger, out io.Writer, opts *migrateOpts) (bool, error) { - if err := opts.validate(); err != nil { - return false, err - } - - dbPool, err := openDBPool(ctx, opts.DatabaseURL) - if err != nil { - return false, err - } - defer dbPool.Close() - - migrator := rivermigrate.New(riverpgxv5.New(dbPool), &rivermigrate.Config{Logger: logger}) - - res, err := migrator.Migrate(ctx, rivermigrate.DirectionUp, &rivermigrate.MigrateOpts{ - DryRun: opts.DryRun, - MaxSteps: opts.MaxSteps, - 
TargetVersion: opts.TargetVersion, - }) - if err != nil { - return false, err - } - - migratePrintResult(out, opts, res, rivermigrate.DirectionUp) - - return true, nil -} - -type validateOpts struct { - DatabaseURL string -} - -func (o *validateOpts) validate() error { - if o.DatabaseURL == "" { - return errors.New("database URL cannot be empty") - } - - return nil -} - -func validate(ctx context.Context, logger *slog.Logger, _ io.Writer, opts *validateOpts) (bool, error) { - if err := opts.validate(); err != nil { - return false, err - } - - dbPool, err := openDBPool(ctx, opts.DatabaseURL) - if err != nil { - return false, err - } - defer dbPool.Close() - - migrator := rivermigrate.New(riverpgxv5.New(dbPool), &rivermigrate.Config{Logger: logger}) - - res, err := migrator.Validate(ctx) - if err != nil { - return false, err - } - - return res.OK, nil -} diff --git a/cmd/river/riverbench/river_bench.go b/cmd/river/riverbench/river_bench.go index 638dcd8f..85776799 100644 --- a/cmd/river/riverbench/river_bench.go +++ b/cmd/river/riverbench/river_bench.go @@ -17,26 +17,22 @@ import ( ) type Benchmarker[TTx any] struct { - driver riverdriver.Driver[TTx] // database pool wrapped in River driver - duration time.Duration // duration to run when running or a duration - logger *slog.Logger // logger, also injected to client - name string // name of the service for logging purposes - numTotalJobs int // total number of jobs to work when in burn down mode + driver riverdriver.Driver[TTx] // database pool wrapped in River driver + logger *slog.Logger // logger, also injected to client + name string // name of the service for logging purposes } -func NewBenchmarker[TTx any](driver riverdriver.Driver[TTx], logger *slog.Logger, duration time.Duration, numTotalJobs int) *Benchmarker[TTx] { +func NewBenchmarker[TTx any](driver riverdriver.Driver[TTx], logger *slog.Logger) *Benchmarker[TTx] { return &Benchmarker[TTx]{ - driver: driver, - duration: duration, - logger: logger, - name: 
"Benchmarker", - numTotalJobs: numTotalJobs, + driver: driver, + logger: logger, + name: "Benchmarker", } } // Run starts the benchmarking loop. Stops upon receiving SIGINT/SIGTERM, or // when reaching maximum configured run duration. -func (b *Benchmarker[TTx]) Run(ctx context.Context) error { +func (b *Benchmarker[TTx]) Run(ctx context.Context, duration time.Duration, numTotalJobs int) error { var ( lastJobWorkedAt time.Time numJobsInserted atomic.Int64 @@ -206,8 +202,8 @@ func (b *Benchmarker[TTx]) Run(ctx context.Context) error { minJobsReady := make(chan struct{}) - if b.numTotalJobs != 0 { - b.insertJobs(ctx, client, minJobsReady, &numJobsInserted, &numJobsLeft, shutdown) + if numTotalJobs != 0 { + b.insertJobs(ctx, client, minJobsReady, &numJobsInserted, &numJobsLeft, numTotalJobs, shutdown) } else { insertJobsFinished := make(chan struct{}) defer func() { <-insertJobsFinished }() @@ -271,7 +267,7 @@ func (b *Benchmarker[TTx]) Run(ctx context.Context) error { for numIterations := 0; ; numIterations++ { // Use iterations multiplied by period time instead of actual elapsed // time to allow a precise, predictable run duration to be specified. - if b.duration != 0 && time.Duration(numIterations)*iterationPeriod >= b.duration { + if duration != 0 && time.Duration(numIterations)*iterationPeriod >= duration { return nil } @@ -298,7 +294,7 @@ func (b *Benchmarker[TTx]) Run(ctx context.Context) error { // If working in the mode where we're burning jobs down and there are no // jobs left, end. 
- if b.numTotalJobs != 0 && numJobsLeft.Load() < 1 { + if numTotalJobs != 0 && numJobsLeft.Load() < 1 { return nil } @@ -328,6 +324,7 @@ func (b *Benchmarker[TTx]) insertJobs( minJobsReady chan struct{}, numJobsInserted *atomic.Int64, numJobsLeft *atomic.Int64, + numTotalJobs int, shutdown chan struct{}, ) { defer close(minJobsReady) @@ -353,7 +350,7 @@ func (b *Benchmarker[TTx]) insertJobs( insertParamsBatch[i].Args = jobArgsBatch[i] } - numLeft := b.numTotalJobs - numInsertedThisRound + numLeft := numTotalJobs - numInsertedThisRound if numLeft < insertBatchSize { insertParamsBatch = insertParamsBatch[0:numLeft] } @@ -367,7 +364,7 @@ func (b *Benchmarker[TTx]) insertJobs( numJobsLeft.Add(int64(len(insertParamsBatch))) numInsertedThisRound += len(insertParamsBatch) - if numJobsLeft.Load() >= int64(b.numTotalJobs) { + if numJobsLeft.Load() >= int64(numTotalJobs) { b.logger.InfoContext(ctx, b.name+": Finished inserting jobs", "duration", time.Since(start), "num_inserted", numInsertedThisRound) return diff --git a/cmd/river/rivercli/command.go b/cmd/river/rivercli/command.go new file mode 100644 index 00000000..b2ee032d --- /dev/null +++ b/cmd/river/rivercli/command.go @@ -0,0 +1,154 @@ +package rivercli + +import ( + "context" + "errors" + "fmt" + "io" + "log/slog" + "os" + "strconv" + "strings" + "time" + + "github.com/jackc/pgx/v5/pgxpool" + + "github.com/riverqueue/river/cmd/river/riverbench" + "github.com/riverqueue/river/rivermigrate" +) + +// BenchmarkerInterface is an interface to a Benchmarker. Its reason for +// existence is to wrap a benchmarker to strip it of its generic parameter, +// letting us pass it around without having to know the transaction type. +type BenchmarkerInterface interface { + Run(ctx context.Context, duration time.Duration, numTotalJobs int) error +} + +// MigratorInterface is an interface to a Migrator. 
Its reason for existence is +// to wrap a migrator to strip it of its generic parameter, letting us pass it +// around without having to know the transaction type. +type MigratorInterface interface { + AllVersions() []rivermigrate.Migration + GetVersion(version int) (rivermigrate.Migration, error) + Migrate(ctx context.Context, direction rivermigrate.Direction, opts *rivermigrate.MigrateOpts) (*rivermigrate.MigrateResult, error) + Validate(ctx context.Context) (*rivermigrate.ValidateResult, error) +} + +// Command is an interface to a River CLI subcommand. Commands generally only +// implement a Run function, and get the rest of the implementation by embedding +// CommandBase. +type Command[TOpts CommandOpts] interface { + Run(ctx context.Context, opts TOpts) (bool, error) + SetCommandBase(b *CommandBase) +} + +// CommandBase provides common facilities for a River CLI command. It's +// generally embedded on the struct of a command. +type CommandBase struct { + DriverProcurer DriverProcurer + Logger *slog.Logger + Out io.Writer + + GetBenchmarker func() BenchmarkerInterface + GetMigrator func(config *rivermigrate.Config) MigratorInterface +} + +func (b *CommandBase) SetCommandBase(base *CommandBase) { + *b = *base +} + +// CommandOpts is the interface for a command's options. It makes sure that options +// provide a way of validating themselves. +type CommandOpts interface { + Validate() error +} + +// RunCommandBundle is a bundle of utilities for RunCommand. +type RunCommandBundle struct { + DatabaseURL *string + DriverProcurer DriverProcurer + Logger *slog.Logger +} + +// RunCommand bootstraps and runs a River CLI subcommand. 
+func RunCommand[TOpts CommandOpts](ctx context.Context, bundle *RunCommandBundle, command Command[TOpts], opts TOpts) { + procureAndRun := func() (bool, error) { + if err := opts.Validate(); err != nil { + return false, err + } + + commandBase := &CommandBase{ + DriverProcurer: bundle.DriverProcurer, + Logger: bundle.Logger, + Out: os.Stdout, + } + + switch { + // If database URL is still nil after Validate check, then assume this + // command doesn't take one. + case bundle.DatabaseURL == nil: + commandBase.GetBenchmarker = func() BenchmarkerInterface { panic("databaseURL was not set") } + commandBase.GetMigrator = func(config *rivermigrate.Config) MigratorInterface { panic("databaseURL was not set") } + + case strings.HasPrefix(*bundle.DatabaseURL, "postgres://"): + dbPool, err := openPgxV5DBPool(ctx, *bundle.DatabaseURL) + if err != nil { + return false, err + } + defer dbPool.Close() + + driver := bundle.DriverProcurer.ProcurePgxV5(dbPool) + + commandBase.GetBenchmarker = func() BenchmarkerInterface { return riverbench.NewBenchmarker(driver, commandBase.Logger) } + commandBase.GetMigrator = func(config *rivermigrate.Config) MigratorInterface { return rivermigrate.New(driver, config) } + + default: + return false, errors.New("unsupport database URL; try one with a prefix of `postgres://...`") + } + + command.SetCommandBase(commandBase) + + return command.Run(ctx, opts) + } + + ok, err := procureAndRun() + if err != nil { + fmt.Fprintf(os.Stderr, "failed: %s\n", err) + } + if err != nil || !ok { + os.Exit(1) + } +} + +func openPgxV5DBPool(ctx context.Context, databaseURL string) (*pgxpool.Pool, error) { + const ( + defaultIdleInTransactionSessionTimeout = 11 * time.Second // should be greater than statement timeout because statements count towards idle-in-transaction + defaultStatementTimeout = 10 * time.Second + ) + + pgxConfig, err := pgxpool.ParseConfig(databaseURL) + if err != nil { + return nil, fmt.Errorf("error parsing database URL: %w", err) + } + + // 
Sets a parameter in a parameter map (aimed at a Postgres connection + // configuration map), but only if that parameter wasn't already set. + setParamIfUnset := func(runtimeParams map[string]string, name, val string) { + if currentVal := runtimeParams[name]; currentVal != "" { + return + } + + runtimeParams[name] = val + } + + setParamIfUnset(pgxConfig.ConnConfig.RuntimeParams, "application_name", "river CLI") + setParamIfUnset(pgxConfig.ConnConfig.RuntimeParams, "idle_in_transaction_session_timeout", strconv.Itoa(int(defaultIdleInTransactionSessionTimeout.Milliseconds()))) + setParamIfUnset(pgxConfig.ConnConfig.RuntimeParams, "statement_timeout", strconv.Itoa(int(defaultStatementTimeout.Milliseconds()))) + + dbPool, err := pgxpool.NewWithConfig(ctx, pgxConfig) + if err != nil { + return nil, fmt.Errorf("error connecting to database: %w", err) + } + + return dbPool, nil +} diff --git a/cmd/river/rivercli/river_cli.go b/cmd/river/rivercli/river_cli.go new file mode 100644 index 00000000..48c9e4c6 --- /dev/null +++ b/cmd/river/rivercli/river_cli.go @@ -0,0 +1,480 @@ +// Package rivercli provides an implementation for the River CLI. +// +// This package is largely for internal use and doesn't provide the same API +// guarantees as the main River modules. Breaking API changes will be made +// without warning. +package rivercli + +import ( + "context" + "errors" + "fmt" + "io" + "log/slog" + "os" + "slices" + "strings" + "time" + + "github.com/jackc/pgx/v5" + "github.com/jackc/pgx/v5/pgxpool" + "github.com/lmittmann/tint" + "github.com/spf13/cobra" + + "github.com/riverqueue/river/riverdriver" + "github.com/riverqueue/river/rivermigrate" +) + +// DriverProcurer is an interface that provides a way of procuring drivers for +// various supported databases. +type DriverProcurer interface { + ProcurePgxV5(pool *pgxpool.Pool) riverdriver.Driver[pgx.Tx] +} + +// CLI provides a common base of commands for the River CLI. 
+type CLI struct { + driverProcurer DriverProcurer +} + +func NewCLI(driverProcurer DriverProcurer) *CLI { + return &CLI{ + driverProcurer: driverProcurer, + } +} + +// BaseCommandSet provides a base River CLI command set which may be further +// augmented with additional commands. +func (c *CLI) BaseCommandSet() *cobra.Command { + var rootOpts struct { + Debug bool + Verbose bool + } + rootCmd := &cobra.Command{ + Use: "river", + Short: "Provides command line facilities for the River job queue", + Long: strings.TrimSpace(` +Provides command line facilities for the River job queue. + `), + Run: func(cmd *cobra.Command, args []string) { + _ = cmd.Usage() + }, + } + rootCmd.PersistentFlags().BoolVar(&rootOpts.Debug, "debug", false, "output maximum logging verbosity (debug level)") + rootCmd.PersistentFlags().BoolVarP(&rootOpts.Verbose, "verbose", "v", false, "output additional logging verbosity (info level)") + rootCmd.MarkFlagsMutuallyExclusive("debug", "verbose") + + ctx := context.Background() + + makeLogger := func() *slog.Logger { + switch { + case rootOpts.Debug: + return slog.New(tint.NewHandler(os.Stdout, &tint.Options{Level: slog.LevelDebug})) + case rootOpts.Verbose: + return slog.New(tint.NewHandler(os.Stdout, nil)) + default: + return slog.New(tint.NewHandler(os.Stdout, &tint.Options{Level: slog.LevelWarn})) + } + } + + // Make a bundle for RunCommand. Takes a database URL pointer because not every command is required to take a database URL. + makeCommandBundle := func(databaseURL *string) *RunCommandBundle { + return &RunCommandBundle{ + DatabaseURL: databaseURL, + DriverProcurer: c.driverProcurer, + Logger: makeLogger(), + } + } + + mustMarkFlagRequired := func(cmd *cobra.Command, name string) { + // We just panic here because this will never happen outside of an error + // in development. 
+ if err := cmd.MarkFlagRequired(name); err != nil { + panic(err) + } + } + + addDatabaseURLFlag := func(cmd *cobra.Command, databaseURL *string) { + cmd.Flags().StringVar(databaseURL, "database-url", "", "URL of the database to benchmark (should look like `postgres://...`") + mustMarkFlagRequired(cmd, "database-url") + } + addLineFlag := func(cmd *cobra.Command, line *string) { + cmd.Flags().StringVar(line, "line", "", "migration line to operate on (default: main)") + } + + // bench + { + var opts benchOpts + + cmd := &cobra.Command{ + Use: "bench", + Short: "Run River benchmark", + Long: strings.TrimSpace(` +Run a River benchmark which inserts and works jobs continually, giving a rough +idea of jobs per second and time to work a single job. + +By default, the benchmark will continuously insert and work jobs in perpetuity +until interrupted by SIGINT (Ctrl^C). It can alternatively take a maximum run +duration with --duration, which takes a Go-style duration string like 1m. +Lastly, it can take --num-total-jobs, which inserts the given number of jobs +before starting the client, and works until all jobs are finished. + +The database in --database-url will have its jobs table truncated, so make sure +to use a development database only. + `), + Run: func(cmd *cobra.Command, args []string) { + RunCommand(ctx, makeCommandBundle(&opts.DatabaseURL), &bench{}, &opts) + }, + } + addDatabaseURLFlag(cmd, &opts.DatabaseURL) + cmd.Flags().DurationVar(&opts.Duration, "duration", 0, "duration after which to stop benchmark, accepting Go-style durations like 1m, 5m30s") + cmd.Flags().IntVarP(&opts.NumTotalJobs, "num-total-jobs", "n", 0, "number of jobs to insert before starting and which are worked down until finish") + rootCmd.AddCommand(cmd) + } + + // migrate-down and migrate-up share a set of options, so this is a way of + // plugging in all the right flags to both so options and docstrings stay + // consistent. 
+ addMigrateFlags := func(cmd *cobra.Command, opts *migrateOpts) { + addDatabaseURLFlag(cmd, &opts.DatabaseURL) + cmd.Flags().BoolVar(&opts.DryRun, "dry-run", false, "print information on migrations, but don't apply them") + cmd.Flags().StringVar(&opts.Line, "line", "", "migration line to operate on (default: main)") + cmd.Flags().IntVar(&opts.MaxSteps, "max-steps", 0, "maximum number of steps to migrate") + cmd.Flags().BoolVar(&opts.ShowSQL, "show-sql", false, "show SQL of each migration") + cmd.Flags().IntVar(&opts.TargetVersion, "target-version", 0, "target version to migrate to (final state includes this version, but none after it)") + } + + // migrate-down + { + var opts migrateOpts + + cmd := &cobra.Command{ + Use: "migrate-down", + Short: "Run River schema down migrations", + Long: strings.TrimSpace(` +Run down migrations to reverse the River database schema changes. + +Defaults to running a single down migration. This behavior can be changed with +--max-steps or --target-version. + +SQL being run can be output using --show-sql, and executing real database +operations can be prevented with --dry-run. Combine --show-sql and --dry-run to +dump prospective migrations that would be applied to stdout. + `), + Run: func(cmd *cobra.Command, args []string) { + RunCommand(ctx, makeCommandBundle(&opts.DatabaseURL), &migrateDown{}, &opts) + }, + } + addMigrateFlags(cmd, &opts) + rootCmd.AddCommand(cmd) + } + + // migrate-get + { + var opts migrateGetOpts + + cmd := &cobra.Command{ + Use: "migrate-get", + Short: "Get SQL for specific River migration", + Long: strings.TrimSpace(` +Retrieve SQL for a single migration version. This command is aimed at cases +where using River's internal migration framework isn't desirable by allowing +migration SQL to be dumped for use elsewhere. 
+ +Specify a version with --version, and one of --down or --up: + + river migrate-get --version 3 --up > river3.up.sql + river migrate-get --version 3 --down > river3.down.sql + +Can also take multiple versions by separating them with commas or passing +--version multiple times: + + river migrate-get --version 1,2,3 --up > river.up.sql + river migrate-get --version 3,2,1 --down > river.down.sql + +Or use --all to print all known migrations in either direction. Often used in +conjunction with --exclude-version 1 to exclude the tables for River's migration +framework, which aren't necessary if using an external framework: + + river migrate-get --all --exclude-version 1 --up > river_all.up.sql + river migrate-get --all --exclude-version 1 --down > river_all.down.sql + `), + Run: func(cmd *cobra.Command, args []string) { + RunCommand(ctx, makeCommandBundle(nil), &migrateGet{}, &opts) + }, + } + cmd.Flags().BoolVar(&opts.All, "all", false, "print all migrations; down migrations are printed in descending order") + cmd.Flags().BoolVar(&opts.Down, "down", false, "print down migration") + cmd.Flags().IntSliceVar(&opts.ExcludeVersion, "exclude-version", nil, "exclude version(s), usually version 1, containing River's migration tables") + addLineFlag(cmd, &opts.Line) + cmd.Flags().BoolVar(&opts.Up, "up", false, "print up migration") + cmd.Flags().IntSliceVar(&opts.Version, "version", nil, "version(s) to print (can be multiple versions)") + cmd.MarkFlagsMutuallyExclusive("all", "version") + cmd.MarkFlagsOneRequired("all", "version") + cmd.MarkFlagsMutuallyExclusive("down", "up") + cmd.MarkFlagsOneRequired("down", "up") + rootCmd.AddCommand(cmd) + } + + // migrate-up + { + var opts migrateOpts + + cmd := &cobra.Command{ + Use: "migrate-up", + Short: "Run River schema up migrations", + Long: strings.TrimSpace(` +Run up migrations to raise the database schema necessary to run River. + +Defaults to running all up migrations that aren't yet run. 
This behavior can be +restricted with --max-steps or --target-version. + +SQL being run can be output using --show-sql, and executing real database +operations can be prevented with --dry-run. Combine --show-sql and --dry-run to +dump prospective migrations that would be applied to stdout. + `), + Run: func(cmd *cobra.Command, args []string) { + RunCommand(ctx, makeCommandBundle(&opts.DatabaseURL), &migrateUp{}, &opts) + }, + } + addMigrateFlags(cmd, &opts) + rootCmd.AddCommand(cmd) + } + + // validate + { + var opts validateOpts + + cmd := &cobra.Command{ + Use: "validate", + Short: "Validate River schema", + Long: strings.TrimSpace(` +Validates the current River schema, exiting with a non-zero status in case there +are outstanding migrations that still need to be run. + +Can be paired with river migrate-up --dry-run --show-sql to dump information on +migrations that need to be run, but without running them. + `), + Run: func(cmd *cobra.Command, args []string) { + RunCommand(ctx, makeCommandBundle(&opts.DatabaseURL), &validate{}, &opts) + }, + } + addDatabaseURLFlag(cmd, &opts.DatabaseURL) + cmd.Flags().StringVar(&opts.Line, "line", "", "migration line to operate on (default: main)") + rootCmd.AddCommand(cmd) + } + + return rootCmd +} + +type benchOpts struct { + DatabaseURL string + Debug bool + Duration time.Duration + NumTotalJobs int + Verbose bool +} + +func (o *benchOpts) Validate() error { + if o.DatabaseURL == "" { + return errors.New("database URL cannot be empty") + } + + return nil +} + +type bench struct { + CommandBase +} + +func (c *bench) Run(ctx context.Context, opts *benchOpts) (bool, error) { + if err := c.GetBenchmarker().Run(ctx, opts.Duration, opts.NumTotalJobs); err != nil { + return false, err + } + return true, nil +} + +type migrateOpts struct { + DatabaseURL string + DryRun bool + Line string + ShowSQL bool + MaxSteps int + TargetVersion int +} + +func (o *migrateOpts) Validate() error { + if o.DatabaseURL == "" { + return 
errors.New("database URL cannot be empty") + } + + return nil +} + +type migrateDown struct { + CommandBase +} + +func (c *migrateDown) Run(ctx context.Context, opts *migrateOpts) (bool, error) { + res, err := c.GetMigrator(&rivermigrate.Config{Line: opts.Line, Logger: c.Logger}).Migrate(ctx, rivermigrate.DirectionDown, &rivermigrate.MigrateOpts{ + DryRun: opts.DryRun, + MaxSteps: opts.MaxSteps, + TargetVersion: opts.TargetVersion, + }) + if err != nil { + return false, err + } + + migratePrintResult(c.Out, opts, res, rivermigrate.DirectionDown) + + return true, nil +} + +func migratePrintResult(out io.Writer, opts *migrateOpts, res *rivermigrate.MigrateResult, direction rivermigrate.Direction) { + if len(res.Versions) < 1 { + fmt.Fprintf(out, "no migrations to apply\n") + return + } + + for _, migrateVersion := range res.Versions { + if opts.DryRun { + fmt.Fprintf(out, "migration %03d [%s] [DRY RUN]\n", migrateVersion.Version, direction) + } else { + fmt.Fprintf(out, "applied migration %03d [%s] [%s]\n", migrateVersion.Version, direction, migrateVersion.Duration) + } + + if opts.ShowSQL { + fmt.Fprintf(out, "%s\n", strings.Repeat("-", 80)) + fmt.Fprintf(out, "%s\n", migrationComment(migrateVersion.Version, direction)) + fmt.Fprintf(out, "%s\n\n", strings.TrimSpace(migrateVersion.SQL)) + } + } + + // Only prints if more steps than available were requested. + if opts.MaxSteps > 0 && len(res.Versions) < opts.MaxSteps { + fmt.Fprintf(out, "no more migrations to apply\n") + } +} + +// An informational comment that's tagged on top of any migration's SQL to help +// attribute what it is for when it's copied elsewhere like other migration +// frameworks. 
+func migrationComment(version int, direction rivermigrate.Direction) string { + return fmt.Sprintf("-- River migration %03d [%s]", version, direction) +} + +type migrateGetOpts struct { + All bool + Down bool + ExcludeVersion []int + Line string + Up bool + Version []int +} + +func (o *migrateGetOpts) Validate() error { return nil } + +type migrateGet struct { + CommandBase +} + +func (c *migrateGet) Run(_ context.Context, opts *migrateGetOpts) (bool, error) { + // We'll need to have a way of using an alternate driver if support for + // other databases is added in the future. Unlike other migrate commands, + // this one doesn't take a `--database-url`, so we'd need a way of + // detecting the database type. + migrator := rivermigrate.New(c.DriverProcurer.ProcurePgxV5(nil), &rivermigrate.Config{Line: opts.Line, Logger: c.Logger}) + + var migrations []rivermigrate.Migration + if opts.All { + migrations = migrator.AllVersions() + if opts.Down { + slices.Reverse(migrations) + } + } else { + for _, version := range opts.Version { + migration, err := migrator.GetVersion(version) + if err != nil { + return false, err + } + + migrations = append(migrations, migration) + } + } + + var printedOne bool + + for _, migration := range migrations { + if slices.Contains(opts.ExcludeVersion, migration.Version) { + continue + } + + // print newlines between multiple versions + if printedOne { + fmt.Fprintf(c.Out, "\n") + } + + var ( + direction rivermigrate.Direction + sql string + ) + switch { + case opts.Down: + direction = rivermigrate.DirectionDown + sql = migration.SQLDown + case opts.Up: + direction = rivermigrate.DirectionUp + sql = migration.SQLUp + } + + printedOne = true + fmt.Fprintf(c.Out, "%s\n", migrationComment(migration.Version, direction)) + fmt.Fprintf(c.Out, "%s\n", strings.TrimSpace(sql)) + } + + return true, nil +} + +type migrateUp struct { + CommandBase +} + +func (c *migrateUp) Run(ctx context.Context, opts *migrateOpts) (bool, error) { + res, err := 
c.GetMigrator(&rivermigrate.Config{Line: opts.Line, Logger: c.Logger}).Migrate(ctx, rivermigrate.DirectionUp, &rivermigrate.MigrateOpts{ + DryRun: opts.DryRun, + MaxSteps: opts.MaxSteps, + TargetVersion: opts.TargetVersion, + }) + if err != nil { + return false, err + } + + migratePrintResult(c.Out, opts, res, rivermigrate.DirectionUp) + + return true, nil +} + +type validateOpts struct { + DatabaseURL string + Line string +} + +func (o *validateOpts) Validate() error { + if o.DatabaseURL == "" { + return errors.New("database URL cannot be empty") + } + + return nil +} + +type validate struct { + CommandBase +} + +func (c *validate) Run(ctx context.Context, opts *validateOpts) (bool, error) { + res, err := c.GetMigrator(&rivermigrate.Config{Line: opts.Line, Logger: c.Logger}).Validate(ctx) + if err != nil { + return false, err + } + + return res.OK, nil +} diff --git a/cmd/river/main_test.go b/cmd/river/rivercli/river_cli_test.go similarity index 95% rename from cmd/river/main_test.go rename to cmd/river/rivercli/river_cli_test.go index 7a47fe70..066239e7 100644 --- a/cmd/river/main_test.go +++ b/cmd/river/rivercli/river_cli_test.go @@ -1,4 +1,4 @@ -package main +package rivercli import ( "testing"