Skip to content

Commit

Permalink
add an override on the DB instance instead of using the global NowFun…
Browse files Browse the repository at this point in the history
…c. (go-gorm#2142)
  • Loading branch information
rubensayshi authored and jinzhu committed Jun 10, 2019
1 parent af01854 commit 712c465
Show file tree
Hide file tree
Showing 7 changed files with 68 additions and 8 deletions.
4 changes: 2 additions & 2 deletions callback_create.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ func beforeCreateCallback(scope *Scope) {
// updateTimeStampForCreateCallback will set `CreatedAt`, `UpdatedAt` when creating
func updateTimeStampForCreateCallback(scope *Scope) {
if !scope.HasError() {
now := NowFunc()
now := scope.db.nowFunc()

if createdAtField, ok := scope.FieldByName("CreatedAt"); ok {
if createdAtField.IsBlank {
Expand All @@ -50,7 +50,7 @@ func updateTimeStampForCreateCallback(scope *Scope) {
// createCallback the callback used to insert data into database
func createCallback(scope *Scope) {
if !scope.HasError() {
defer scope.trace(NowFunc())
defer scope.trace(scope.db.nowFunc())

var (
columns, placeholders []string
Expand Down
2 changes: 1 addition & 1 deletion callback_delete.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ func deleteCallback(scope *Scope) {
"UPDATE %v SET %v=%v%v%v",
scope.QuotedTableName(),
scope.Quote(deletedAtField.DBName),
scope.AddToVars(NowFunc()),
scope.AddToVars(scope.db.nowFunc()),
addExtraSpaceIfExist(scope.CombinedConditionSql()),
addExtraSpaceIfExist(extraOption),
)).Exec()
Expand Down
2 changes: 1 addition & 1 deletion callback_query.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ func queryCallback(scope *Scope) {
return
}

defer scope.trace(NowFunc())
defer scope.trace(scope.db.nowFunc())

var (
isSlice, isPtr bool
Expand Down
2 changes: 1 addition & 1 deletion callback_update.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ func beforeUpdateCallback(scope *Scope) {
// updateTimeStampForUpdateCallback will set `UpdatedAt` when updating
func updateTimeStampForUpdateCallback(scope *Scope) {
if _, ok := scope.Get("gorm:update_column"); !ok {
scope.SetColumn("UpdatedAt", NowFunc())
scope.SetColumn("UpdatedAt", scope.db.nowFunc())
}
}

Expand Down
40 changes: 40 additions & 0 deletions create_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,46 @@ func TestCreateWithExistingTimestamp(t *testing.T) {
}
}

func TestCreateWithNowFuncOverride(t *testing.T) {
user1 := User{Name: "CreateUserTimestampOverride"}

timeA := now.MustParse("2016-01-01")

// do DB.New() because we don't want this test to affect other tests
db1 := DB.New()
// set the override to use static timeA
db1.SetNowFuncOverride(func() time.Time {
return timeA
})
// call .New again to check the override is carried over as well during clone
db1 = db1.New()

db1.Save(&user1)

if user1.CreatedAt.UTC().Format(time.RFC3339) != timeA.UTC().Format(time.RFC3339) {
t.Errorf("CreatedAt be using the nowFuncOverride")
}
if user1.UpdatedAt.UTC().Format(time.RFC3339) != timeA.UTC().Format(time.RFC3339) {
t.Errorf("UpdatedAt be using the nowFuncOverride")
}

// now create another user with a fresh DB.Now() that doesn't have the nowFuncOverride set
// to make sure that setting it only affected the above instance

user2 := User{Name: "CreateUserTimestampOverrideNoMore"}

db2 := DB.New()

db2.Save(&user2)

if user2.CreatedAt.UTC().Format(time.RFC3339) == timeA.UTC().Format(time.RFC3339) {
t.Errorf("CreatedAt no longer be using the nowFuncOverride")
}
if user2.UpdatedAt.UTC().Format(time.RFC3339) == timeA.UTC().Format(time.RFC3339) {
t.Errorf("UpdatedAt no longer be using the nowFuncOverride")
}
}

type AutoIncrementUser struct {
User
Sequence uint `gorm:"AUTO_INCREMENT"`
Expand Down
20 changes: 20 additions & 0 deletions main.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,9 @@ type DB struct {
callbacks *Callback
dialect Dialect
singularTable bool

// function to be used to override the creating of a new timestamp
nowFuncOverride func() time.Time
}

type logModeValue int
Expand Down Expand Up @@ -158,6 +161,22 @@ func (s *DB) LogMode(enable bool) *DB {
return s
}

// SetNowFuncOverride set the function to be used when creating a new timestamp
func (s *DB) SetNowFuncOverride(nowFuncOverride func() time.Time) *DB {
s.nowFuncOverride = nowFuncOverride
return s
}

// Get a new timestamp, using the provided nowFuncOverride on the DB instance if set,
// otherwise defaults to the global NowFunc()
func (s *DB) nowFunc() time.Time {
if s.nowFuncOverride != nil {
return s.nowFuncOverride()
}

return NowFunc()
}

// BlockGlobalUpdate if true, generates an error on update/delete without where clause.
// This is to prevent eventual error with empty objects updates/deletions
func (s *DB) BlockGlobalUpdate(enable bool) *DB {
Expand Down Expand Up @@ -800,6 +819,7 @@ func (s *DB) clone() *DB {
Error: s.Error,
blockGlobalUpdate: s.blockGlobalUpdate,
dialect: newDialect(s.dialect.GetName(), s.db),
nowFuncOverride: s.nowFuncOverride,
}

s.values.Range(func(k, v interface{}) bool {
Expand Down
6 changes: 3 additions & 3 deletions scope.go
Original file line number Diff line number Diff line change
Expand Up @@ -358,7 +358,7 @@ func (scope *Scope) Raw(sql string) *Scope {

// Exec perform generated SQL
func (scope *Scope) Exec() *Scope {
defer scope.trace(NowFunc())
defer scope.trace(scope.db.nowFunc())

if !scope.HasError() {
if result, err := scope.SQLDB().Exec(scope.SQL, scope.SQLVars...); scope.Err(err) == nil {
Expand Down Expand Up @@ -932,7 +932,7 @@ func (scope *Scope) updatedAttrsWithValues(value interface{}) (results map[strin
}

func (scope *Scope) row() *sql.Row {
defer scope.trace(NowFunc())
defer scope.trace(scope.db.nowFunc())

result := &RowQueryResult{}
scope.InstanceSet("row_query_result", result)
Expand All @@ -942,7 +942,7 @@ func (scope *Scope) row() *sql.Row {
}

func (scope *Scope) rows() (*sql.Rows, error) {
defer scope.trace(NowFunc())
defer scope.trace(scope.db.nowFunc())

result := &RowsQueryResult{}
scope.InstanceSet("row_query_result", result)
Expand Down

0 comments on commit 712c465

Please sign in to comment.