Skip to content

Commit

Permalink
Replace MigratorFS with fs.FS
Browse files Browse the repository at this point in the history
  • Loading branch information
jackc committed May 7, 2022
1 parent 372b38c commit af6fc5e
Show file tree
Hide file tree
Showing 6 changed files with 65 additions and 91 deletions.
3 changes: 2 additions & 1 deletion README.markdown
Original file line number Diff line number Diff line change
Expand Up @@ -380,7 +380,8 @@ is still available through RubyGems and the source code is on the ruby branch.

## Unreleased 2

* Remove deprecated env access syntax in config file.
* Remove deprecated env access syntax in config file
* Replace MigratorFS interface with fs.FS

## 1.13.0 (April 21, 2022)

Expand Down
16 changes: 8 additions & 8 deletions main.go
Original file line number Diff line number Diff line change
Expand Up @@ -387,7 +387,7 @@ func NewMigration(cmd *cobra.Command, args []string) {
name := args[0]

migrationsPath := cliOptions.migrationsPath
migrations, err := migrate.FindMigrations(migrationsPath)
migrations, err := migrate.FindMigrations(os.DirFS(migrationsPath))
if err != nil {
fmt.Fprintf(os.Stderr, "Error loading migrations:\n %v\n", err)
os.Exit(1)
Expand Down Expand Up @@ -447,7 +447,7 @@ func Migrate(cmd *cobra.Command, args []string) {
migrator.Data = config.Data

migrationsPath := cliOptions.migrationsPath
err = migrator.LoadMigrations(migrationsPath)
err = migrator.LoadMigrations(os.DirFS(migrationsPath))
if err != nil {
fmt.Fprintf(os.Stderr, "Error loading migrations:\n %v\n", err)
os.Exit(1)
Expand Down Expand Up @@ -534,7 +534,7 @@ func Migrate(cmd *cobra.Command, args []string) {
func InstallCode(cmd *cobra.Command, args []string) {
path := args[0]

codePackage, err := migrate.LoadCodePackage(path)
codePackage, err := migrate.LoadCodePackage(os.DirFS(path))
if err != nil {
fmt.Fprintf(os.Stderr, "Failed to load code package:\n %v\n", err)
os.Exit(1)
Expand Down Expand Up @@ -581,7 +581,7 @@ func InstallCode(cmd *cobra.Command, args []string) {
func CompileCode(cmd *cobra.Command, args []string) {
path := args[0]

codePackage, err := migrate.LoadCodePackage(path)
codePackage, err := migrate.LoadCodePackage(os.DirFS(path))
if err != nil {
fmt.Fprintf(os.Stderr, "Failed to load code package:\n %v\n", err)
os.Exit(1)
Expand Down Expand Up @@ -611,14 +611,14 @@ func CompileCode(cmd *cobra.Command, args []string) {
func SnapshotCode(cmd *cobra.Command, args []string) {
path := args[0]

_, err := migrate.LoadCodePackage(path)
_, err := migrate.LoadCodePackage(os.DirFS(path))
if err != nil {
fmt.Fprintf(os.Stderr, "Failed to load code package:\n %v\n", err)
os.Exit(1)
}

migrationsPath := cliOptions.migrationsPath
migrations, err := migrate.FindMigrations(migrationsPath)
migrations, err := migrate.FindMigrations(os.DirFS(migrationsPath))
if err != nil {
fmt.Fprintf(os.Stderr, "Error loading migrations:\n %v\n", err)
os.Exit(1)
Expand Down Expand Up @@ -664,7 +664,7 @@ func Status(cmd *cobra.Command, args []string) {
migrator.Data = config.Data

migrationsPath := cliOptions.migrationsPath
err = migrator.LoadMigrations(migrationsPath)
err = migrator.LoadMigrations(os.DirFS(migrationsPath))
if err != nil {
fmt.Fprintf(os.Stderr, "Error loading migrations:\n %v\n", err)
os.Exit(1)
Expand Down Expand Up @@ -696,7 +696,7 @@ func Status(cmd *cobra.Command, args []string) {

func RenumberStart(cmd *cobra.Command, args []string) {
migrationsPath := cliOptions.migrationsPath
migrations, err := migrate.FindMigrations(migrationsPath)
migrations, err := migrate.FindMigrations(os.DirFS(migrationsPath))
if err != nil {
fmt.Fprintf(os.Stderr, "Error loading migrations:\n %v\n", err)
os.Exit(1)
Expand Down
37 changes: 17 additions & 20 deletions migrate/code.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@ import (
"context"
"errors"
"fmt"
"io/fs"
"path/filepath"
"strings"
"text/template"

"github.com/Masterminds/sprig"
Expand All @@ -28,55 +28,56 @@ func (cp *CodePackage) Eval(data map[string]interface{}) (string, error) {
return buf.String(), nil
}

func findCodeFiles(dirname string, fs MigratorFS) ([]string, error) {
dirname = strings.TrimRight(dirname, string(filepath.Separator))

entries, err := fs.ReadDir(dirname)
func findCodeFiles(fsys fs.FS) ([]string, error) {
entries, err := fs.ReadDir(fsys, ".")
if err != nil {
return nil, err
}

var results []string

for _, e := range entries {
ePath := filepath.Join(dirname, e.Name())
if e.IsDir() {
paths, err := findCodeFiles(ePath, fs)
subfs, err := fs.Sub(fsys, e.Name())
if err != nil {
return nil, err
}
results = append(results, paths...)
paths, err := findCodeFiles(subfs)
if err != nil {
return nil, err
}

for _, p := range paths {
results = append(results, filepath.Join(e.Name(), p))
}
} else {
match, err := filepath.Match("*.sql", e.Name())
if err != nil {
return nil, fmt.Errorf("impossible filepath.Match error %w", err)
}
if match {
results = append(results, ePath)
results = append(results, e.Name())
}
}
}

return results, nil
}

func LoadCodePackageEx(path string, fs MigratorFS) (*CodePackage, error) {
path = strings.TrimRight(path, string(filepath.Separator))

func LoadCodePackage(fsys fs.FS) (*CodePackage, error) {
mainTmpl := template.New("main").Funcs(sprig.TxtFuncMap())
sqlPaths, err := findCodeFiles(path, fs)
sqlPaths, err := findCodeFiles(fsys)
if err != nil {
return nil, err
}

for _, p := range sqlPaths {
body, err := fs.ReadFile(p)
body, err := fs.ReadFile(fsys, p)
if err != nil {
return nil, err
}

name := strings.Replace(p, path+string(filepath.Separator), "", 1)
_, err = mainTmpl.New(name).Parse(string(body))
_, err = mainTmpl.New(p).Parse(string(body))
if err != nil {
return nil, err
}
Expand All @@ -92,10 +93,6 @@ func LoadCodePackageEx(path string, fs MigratorFS) (*CodePackage, error) {
return codePackage, nil
}

func LoadCodePackage(path string) (*CodePackage, error) {
return LoadCodePackageEx(path, defaultMigratorFS{})
}

func InstallCodePackage(ctx context.Context, conn *pgx.Conn, mergeData map[string]interface{}, codePackage *CodePackage) (err error) {
sql, err := codePackage.Eval(mergeData)
if err != nil {
Expand Down
7 changes: 4 additions & 3 deletions migrate/code_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package migrate_test

import (
"context"
"os"
"testing"

"github.com/jackc/tern/migrate"
Expand All @@ -10,18 +11,18 @@ import (
)

func TestLoadCodePackage(t *testing.T) {
codePackage, err := migrate.LoadCodePackage("testdata/code")
codePackage, err := migrate.LoadCodePackage(os.DirFS("testdata/code"))
assert.NoError(t, err)
assert.NotNil(t, codePackage)
}

func TestLoadCodePackageNotCodePackage(t *testing.T) {
codePackage, err := migrate.LoadCodePackage("testdata/sample")
codePackage, err := migrate.LoadCodePackage(os.DirFS("testdata/sample"))
assert.EqualError(t, err, "install.sql not found")
assert.Nil(t, codePackage)
}
func TestInstallCodePackage(t *testing.T) {
codePackage, err := migrate.LoadCodePackage("testdata/code")
codePackage, err := migrate.LoadCodePackage(os.DirFS("testdata/code"))
require.NoError(t, err)
require.NotNil(t, codePackage)

Expand Down
67 changes: 21 additions & 46 deletions migrate/migrate.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ import (
"context"
"errors"
"fmt"
"io/ioutil"
"io/fs"
"os"
"path/filepath"
"regexp"
Expand Down Expand Up @@ -37,11 +37,10 @@ func (e IrreversibleMigrationError) Error() string {
}

type NoMigrationsFoundError struct {
Path string
}

func (e NoMigrationsFoundError) Error() string {
return fmt.Sprintf("No migrations found at %s", e.Path)
return "migrations not found"
}

type MigrationPgError struct {
Expand Down Expand Up @@ -71,8 +70,8 @@ type Migration struct {
type MigratorOptions struct {
// DisableTx causes the Migrator not to run migrations in a transaction.
DisableTx bool
// MigratorFS is the interface used for collecting the migrations.
MigratorFS MigratorFS
// FileSystem is the interface used for collecting the migrations.
FileSystem fs.FS
}

type Migrator struct {
Expand All @@ -86,7 +85,7 @@ type Migrator struct {

// NewMigrator initializes a new Migrator. It is highly recommended that versionTable be schema qualified.
func NewMigrator(ctx context.Context, conn *pgx.Conn, versionTable string) (m *Migrator, err error) {
return NewMigratorEx(ctx, conn, versionTable, &MigratorOptions{MigratorFS: defaultMigratorFS{}})
return NewMigratorEx(ctx, conn, versionTable, &MigratorOptions{FileSystem: os.DirFS("/")})
}

// NewMigratorEx initializes a new Migrator. It is highly recommended that versionTable be schema qualified.
Expand All @@ -98,30 +97,9 @@ func NewMigratorEx(ctx context.Context, conn *pgx.Conn, versionTable string, opt
return
}

type MigratorFS interface {
ReadDir(dirname string) ([]os.FileInfo, error)
ReadFile(filename string) ([]byte, error)
Glob(pattern string) (matches []string, err error)
}

type defaultMigratorFS struct{}

func (defaultMigratorFS) ReadDir(dirname string) ([]os.FileInfo, error) {
return ioutil.ReadDir(dirname)
}

func (defaultMigratorFS) ReadFile(filename string) ([]byte, error) {
return ioutil.ReadFile(filename)
}

func (defaultMigratorFS) Glob(pattern string) ([]string, error) {
return filepath.Glob(pattern)
}

func FindMigrationsEx(path string, fs MigratorFS) ([]string, error) {
path = strings.TrimRight(path, string(filepath.Separator))

fileInfos, err := fs.ReadDir(path)
// FindMigrations finds all migration files in fsys.
func FindMigrations(fsys fs.FS) ([]string, error) {
fileInfos, err := fs.ReadDir(fsys, ".")
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -151,23 +129,21 @@ func FindMigrationsEx(path string, fs MigratorFS) ([]string, error) {
return nil, fmt.Errorf("Missing migration %d", len(paths)+1)
}

paths = append(paths, filepath.Join(path, fi.Name()))
paths = append(paths, fi.Name())
}

return paths, nil
}

func FindMigrations(path string) ([]string, error) {
return FindMigrationsEx(path, defaultMigratorFS{})
}

func (m *Migrator) LoadMigrations(path string) error {
path = strings.TrimRight(path, string(filepath.Separator))

func (m *Migrator) LoadMigrations(fsys fs.FS) error {
mainTmpl := template.New("main").Funcs(sprig.TxtFuncMap()).Funcs(
template.FuncMap{
"install_snapshot": func(name string) (string, error) {
codePackage, err := LoadCodePackageEx(filepath.Join(path, "snapshots", name), m.options.MigratorFS)
codePackageFSys, err := fs.Sub(fsys, filepath.Join("snapshots", name))
if err != nil {
return "", err
}
codePackage, err := LoadCodePackage(codePackageFSys)
if err != nil {
return "", err
}
Expand All @@ -177,35 +153,34 @@ func (m *Migrator) LoadMigrations(path string) error {
},
)

sharedPaths, err := m.options.MigratorFS.Glob(filepath.Join(path, "*", "*.sql"))
sharedPaths, err := fs.Glob(fsys, filepath.Join("*", "*.sql"))
if err != nil {
return err
}

for _, p := range sharedPaths {
body, err := m.options.MigratorFS.ReadFile(p)
body, err := fs.ReadFile(fsys, p)
if err != nil {
return err
}

name := strings.Replace(p, path+string(filepath.Separator), "", 1)
_, err = mainTmpl.New(name).Parse(string(body))
_, err = mainTmpl.New(p).Parse(string(body))
if err != nil {
return err
}
}

paths, err := FindMigrationsEx(path, m.options.MigratorFS)
paths, err := FindMigrations(fsys)
if err != nil {
return err
}

if len(paths) == 0 {
return NoMigrationsFoundError{Path: path}
return NoMigrationsFoundError{}
}

for _, p := range paths {
body, err := m.options.MigratorFS.ReadFile(p)
body, err := fs.ReadFile(fsys, p)
if err != nil {
return err
}
Expand Down
Loading

0 comments on commit af6fc5e

Please sign in to comment.