Skip to content

implemented optional logger to log details when applying migrations #274

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions log.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
package migrate

import (
"context"
)

type Logger interface {
Info(context.Context, string, ...any)
Warn(context.Context, string, ...any)
Error(context.Context, string, ...any)
}
70 changes: 70 additions & 0 deletions log/log.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
package log

import (
"context"
"fmt"
)

// Level is a type to represent the log level
type Level int

const (
LevelSilent Level = iota
LevelError
LevelWarn
LevelInfo
)

type DefaultLogger struct {
level Level
writer Writer
}

type Writer interface {
Printf(string, ...interface{})
}

// DefaultLogWriter is a default implementation of the Writer interface
// that writes using fmt.Printf
type DefaultLogWriter struct{}

func (w *DefaultLogWriter) Printf(format string, args ...interface{}) {
fmt.Printf(format, args...)
}

// NewDefaultLogger creates a new DefaultLogger with a silent log level
func NewDefaultLogger() *DefaultLogger {
return &DefaultLogger{
level: LevelSilent,
writer: &DefaultLogWriter{},
}
}

func (l *DefaultLogger) WithLevel(level Level) *DefaultLogger {
l.level = level
return l
}

func (l *DefaultLogger) WithWriter(writer Writer) *DefaultLogger {
l.writer = writer
return l
}

func (l *DefaultLogger) Info(_ context.Context, format string, args ...interface{}) {
l.logIfPermittedByLevel(LevelInfo, format, args...)
}

func (l *DefaultLogger) Warn(_ context.Context, format string, args ...interface{}) {
l.logIfPermittedByLevel(LevelWarn, format, args...)
}

func (l *DefaultLogger) Error(_ context.Context, format string, args ...interface{}) {
l.logIfPermittedByLevel(LevelError, format, args...)
}

func (l *DefaultLogger) logIfPermittedByLevel(requiredLevel Level, format string, args ...interface{}) {
if l.level < requiredLevel {
return
}
l.writer.Printf(format, args...)
}
79 changes: 79 additions & 0 deletions log/log_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
package log_test

import (
"context"
"fmt"
"testing"

"github.com/rubenv/sql-migrate/log"
)

type mockWriter struct {
logs []string
}

func (mw *mockWriter) Printf(format string, args ...interface{}) {
mw.logs = append(mw.logs, fmt.Sprintf(format, args...))
}

func TestDefaultLoggerWithLevelInfo(t *testing.T) {
mockWriter := &mockWriter{logs: []string{}}

logger := log.NewDefaultLogger().WithLevel(log.LevelInfo).WithWriter(mockWriter)
logger.Info(context.Background(), "This should be logged")
logger.Warn(context.Background(), "This should also be logged")
logger.Error(context.Background(), "This should also be logged")

expectedLogs := []string{
"This should be logged",
"This should also be logged",
"This should also be logged",
}

if len(mockWriter.logs) != len(expectedLogs) {
t.Fatalf("Expected %d logs, got %d", len(expectedLogs), len(mockWriter.logs))
}

for i, expectedLog := range expectedLogs {
if expectedLog != mockWriter.logs[i] {
t.Fatalf("Expected log %d to be %s, got %s", i, expectedLog, mockWriter.logs[i])
}
}
}

func TestDefaultLoggerWithLevelSilent(t *testing.T) {
mockWriter := &mockWriter{logs: []string{}}

logger := log.NewDefaultLogger().WithLevel(log.LevelSilent).WithWriter(mockWriter)
logger.Info(context.Background(), "This should not be logged")
logger.Warn(context.Background(), "This should not be logged")
logger.Error(context.Background(), "This should not be logged")

if len(mockWriter.logs) != 0 {
t.Fatalf("Expected no logs, got %d", len(mockWriter.logs))
}
}

func TestDefaultLoggerWithLevelWarn(t *testing.T) {
mockWriter := &mockWriter{logs: []string{}}

logger := log.NewDefaultLogger().WithLevel(log.LevelWarn).WithWriter(mockWriter)
logger.Info(context.Background(), "This should not be logged")
logger.Warn(context.Background(), "This should be logged")
logger.Error(context.Background(), "This should also be logged")

expectedLogs := []string{
"This should be logged",
"This should also be logged",
}

if len(mockWriter.logs) != len(expectedLogs) {
t.Fatalf("Expected %d logs, got %d", len(expectedLogs), len(mockWriter.logs))
}

for i, expectedLog := range expectedLogs {
if expectedLog != mockWriter.logs[i] {
t.Fatalf("Expected log %d to be %s, got %s", i, expectedLog, mockWriter.logs[i])
}
}
}
50 changes: 37 additions & 13 deletions migrate.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ import (

"github.com/go-gorp/gorp/v3"

"github.com/rubenv/sql-migrate/log"
"github.com/rubenv/sql-migrate/sqlparse"
)

Expand All @@ -42,12 +43,14 @@ type MigrationSet struct {
IgnoreUnknown bool
// DisableCreateTable disable the creation of the migration table
DisableCreateTable bool
// Logger is used to log additional information during the migration process.
Logger Logger
}

var migSet = MigrationSet{}

// NewMigrationSet returns a parametrized Migration object
func (ms MigrationSet) getTableName() string {
func (ms *MigrationSet) getTableName() string {
if ms.TableName == "" {
return "gorp_migrations"
}
Expand Down Expand Up @@ -124,6 +127,10 @@ func SetIgnoreUnknown(v bool) {
migSet.IgnoreUnknown = v
}

func SetLogger(l Logger) {
migSet.Logger = l
}

type Migration struct {
Id string
Up []string
Expand Down Expand Up @@ -448,7 +455,7 @@ func Exec(db *sql.DB, dialect string, m MigrationSource, dir MigrationDirection)
}

// Returns the number of applied migrations.
func (ms MigrationSet) Exec(db *sql.DB, dialect string, m MigrationSource, dir MigrationDirection) (int, error) {
func (ms *MigrationSet) Exec(db *sql.DB, dialect string, m MigrationSource, dir MigrationDirection) (int, error) {
return ms.ExecMaxContext(context.Background(), db, dialect, m, dir, 0)
}

Expand All @@ -460,7 +467,7 @@ func ExecContext(ctx context.Context, db *sql.DB, dialect string, m MigrationSou
}

// Returns the number of applied migrations.
func (ms MigrationSet) ExecContext(ctx context.Context, db *sql.DB, dialect string, m MigrationSource, dir MigrationDirection) (int, error) {
func (ms *MigrationSet) ExecContext(ctx context.Context, db *sql.DB, dialect string, m MigrationSource, dir MigrationDirection) (int, error) {
return ms.ExecMaxContext(ctx, db, dialect, m, dir, 0)
}

Expand Down Expand Up @@ -504,12 +511,12 @@ func ExecVersionContext(ctx context.Context, db *sql.DB, dialect string, m Migra
}

// Returns the number of applied migrations.
func (ms MigrationSet) ExecMax(db *sql.DB, dialect string, m MigrationSource, dir MigrationDirection, max int) (int, error) {
func (ms *MigrationSet) ExecMax(db *sql.DB, dialect string, m MigrationSource, dir MigrationDirection, max int) (int, error) {
return ms.ExecMaxContext(context.Background(), db, dialect, m, dir, max)
}

// Returns the number of applied migrations, but applies with an input context.
func (ms MigrationSet) ExecMaxContext(ctx context.Context, db *sql.DB, dialect string, m MigrationSource, dir MigrationDirection, max int) (int, error) {
func (ms *MigrationSet) ExecMaxContext(ctx context.Context, db *sql.DB, dialect string, m MigrationSource, dir MigrationDirection, max int) (int, error) {
migrations, dbMap, err := ms.PlanMigration(db, dialect, m, dir, max)
if err != nil {
return 0, err
Expand All @@ -518,11 +525,11 @@ func (ms MigrationSet) ExecMaxContext(ctx context.Context, db *sql.DB, dialect s
}

// Returns the number of applied migrations.
func (ms MigrationSet) ExecVersion(db *sql.DB, dialect string, m MigrationSource, dir MigrationDirection, version int64) (int, error) {
func (ms *MigrationSet) ExecVersion(db *sql.DB, dialect string, m MigrationSource, dir MigrationDirection, version int64) (int, error) {
return ms.ExecVersionContext(context.Background(), db, dialect, m, dir, version)
}

func (ms MigrationSet) ExecVersionContext(ctx context.Context, db *sql.DB, dialect string, m MigrationSource, dir MigrationDirection, version int64) (int, error) {
func (ms *MigrationSet) ExecVersionContext(ctx context.Context, db *sql.DB, dialect string, m MigrationSource, dir MigrationDirection, version int64) (int, error) {
migrations, dbMap, err := ms.PlanMigrationToVersion(db, dialect, m, dir, version)
if err != nil {
return 0, err
Expand All @@ -531,9 +538,11 @@ func (ms MigrationSet) ExecVersionContext(ctx context.Context, db *sql.DB, diale
}

// Applies the planned migrations and returns the number of applied migrations.
func (MigrationSet) applyMigrations(ctx context.Context, dir MigrationDirection, migrations []*PlannedMigration, dbMap *gorp.DbMap) (int, error) {
func (m MigrationSet) applyMigrations(ctx context.Context, dir MigrationDirection, migrations []*PlannedMigration, dbMap *gorp.DbMap) (int, error) {
applied := 0
for _, migration := range migrations {
m.logger().Info(ctx, "Applying migration %s", migration.Id)

var executor SqlExecutor
var err error

Expand Down Expand Up @@ -563,6 +572,8 @@ func (MigrationSet) applyMigrations(ctx context.Context, dir MigrationDirection,

switch dir {
case Up:
m.logger().Info(ctx, "Migrating up %s", migration.Id)

err = executor.Insert(&MigrationRecord{
Id: migration.Id,
AppliedAt: time.Now(),
Expand All @@ -575,6 +586,8 @@ func (MigrationSet) applyMigrations(ctx context.Context, dir MigrationDirection,
return applied, newTxError(migration, err)
}
case Down:
m.logger().Info(ctx, "Migrating down %s", migration.Id)

_, err := executor.Delete(&MigrationRecord{
Id: migration.Id,
})
Expand All @@ -590,12 +603,16 @@ func (MigrationSet) applyMigrations(ctx context.Context, dir MigrationDirection,
}

if trans, ok := executor.(*gorp.Transaction); ok {
m.logger().Info(ctx, "Committing transaction for %s", migration.Id)

if err := trans.Commit(); err != nil {
return applied, newTxError(migration, err)
}
}

applied++

m.logger().Info(ctx, "Applied %d/%d migrations", applied, len(migrations))
}

return applied, nil
Expand All @@ -612,17 +629,17 @@ func PlanMigrationToVersion(db *sql.DB, dialect string, m MigrationSource, dir M
}

// Plan a migration.
func (ms MigrationSet) PlanMigration(db *sql.DB, dialect string, m MigrationSource, dir MigrationDirection, max int) ([]*PlannedMigration, *gorp.DbMap, error) {
func (ms *MigrationSet) PlanMigration(db *sql.DB, dialect string, m MigrationSource, dir MigrationDirection, max int) ([]*PlannedMigration, *gorp.DbMap, error) {
return ms.planMigrationCommon(db, dialect, m, dir, max, -1)
}

// Plan a migration to version.
func (ms MigrationSet) PlanMigrationToVersion(db *sql.DB, dialect string, m MigrationSource, dir MigrationDirection, version int64) ([]*PlannedMigration, *gorp.DbMap, error) {
func (ms *MigrationSet) PlanMigrationToVersion(db *sql.DB, dialect string, m MigrationSource, dir MigrationDirection, version int64) ([]*PlannedMigration, *gorp.DbMap, error) {
return ms.planMigrationCommon(db, dialect, m, dir, 0, version)
}

// A common method to plan a migration.
func (ms MigrationSet) planMigrationCommon(db *sql.DB, dialect string, m MigrationSource, dir MigrationDirection, max int, version int64) ([]*PlannedMigration, *gorp.DbMap, error) {
func (ms *MigrationSet) planMigrationCommon(db *sql.DB, dialect string, m MigrationSource, dir MigrationDirection, max int, version int64) ([]*PlannedMigration, *gorp.DbMap, error) {
dbMap, err := ms.getMigrationDbMap(db, dialect)
if err != nil {
return nil, nil, err
Expand Down Expand Up @@ -822,7 +839,7 @@ func GetMigrationRecords(db *sql.DB, dialect string) ([]*MigrationRecord, error)
return migSet.GetMigrationRecords(db, dialect)
}

func (ms MigrationSet) GetMigrationRecords(db *sql.DB, dialect string) ([]*MigrationRecord, error) {
func (ms *MigrationSet) GetMigrationRecords(db *sql.DB, dialect string) ([]*MigrationRecord, error) {
dbMap, err := ms.getMigrationDbMap(db, dialect)
if err != nil {
return nil, err
Expand All @@ -838,7 +855,14 @@ func (ms MigrationSet) GetMigrationRecords(db *sql.DB, dialect string) ([]*Migra
return records, nil
}

func (ms MigrationSet) getMigrationDbMap(db *sql.DB, dialect string) (*gorp.DbMap, error) {
func (ms *MigrationSet) logger() Logger {
if migSet.Logger == nil {
migSet.Logger = log.NewDefaultLogger()
}
return migSet.Logger
}

func (ms *MigrationSet) getMigrationDbMap(db *sql.DB, dialect string) (*gorp.DbMap, error) {
d, ok := MigrationDialects[dialect]
if !ok {
return nil, fmt.Errorf("Unknown dialect: %s", dialect)
Expand Down
Loading