Skip to content

Commit

Permalink
Refactor callbacks
Browse files Browse the repository at this point in the history
  • Loading branch information
jinzhu committed Jan 17, 2016
1 parent 09f46f0 commit de73d30
Show file tree
Hide file tree
Showing 7 changed files with 59 additions and 48 deletions.
15 changes: 13 additions & 2 deletions callback.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ type Callback struct {
processors []*CallbackProcessor
}

// callbackProcessor contains all informations for a callback
// CallbackProcessor contains all informations for a callback
type CallbackProcessor struct {
name string // current callback's name
before string // register current callback before a callback
Expand Down Expand Up @@ -79,7 +79,7 @@ func (c *Callback) Query() *CallbackProcessor {
return c.addProcessor("query")
}

// Query could be used to register callbacks for querying objects with `Row`, `Rows`, refer `Create` for usage
// RowQuery could be used to register callbacks for querying objects with `Row`, `Rows`, refer `Create` for usage
func (c *Callback) RowQuery() *CallbackProcessor {
return c.addProcessor("row_query")
}
Expand Down Expand Up @@ -125,6 +125,17 @@ func (cp *CallbackProcessor) Replace(callbackName string, callback func(scope *S
cp.parent.reorder()
}

// Get registered callback
// db.Callback().Create().Get("gorm:create")
func (cp *CallbackProcessor) Get(callbackName string) (callback func(scope *Scope)) {
for _, processor := range cp.parent.processors {
if processor.name == callbackName && processor.kind == cp.kind && !cp.remove {
return *cp.processor
}
}
return nil
}

// getRIndex get right index from string slice
func getRIndex(strs []string, str string) int {
for i := len(strs) - 1; i >= 0; i-- {
Expand Down
28 changes: 14 additions & 14 deletions callback_create.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,20 +5,20 @@ import (
"strings"
)

func BeforeCreate(scope *Scope) {
func beforeCreateCallback(scope *Scope) {
scope.CallMethodWithErrorCheck("BeforeSave")
scope.CallMethodWithErrorCheck("BeforeCreate")
}

func UpdateTimeStampWhenCreate(scope *Scope) {
func updateTimeStampForCreateCallback(scope *Scope) {
if !scope.HasError() {
now := NowFunc()
scope.SetColumn("CreatedAt", now)
scope.SetColumn("UpdatedAt", now)
}
}

func Create(scope *Scope) {
func createCallback(scope *Scope) {
defer scope.trace(NowFunc())

if !scope.HasError() {
Expand Down Expand Up @@ -102,25 +102,25 @@ func Create(scope *Scope) {
}
}

func ForceReloadAfterCreate(scope *Scope) {
func forceReloadAfterCreateCallback(scope *Scope) {
if columns, ok := scope.InstanceGet("gorm:force_reload_after_create_attrs"); ok {
scope.DB().New().Select(columns.([]string)).First(scope.Value)
}
}

func AfterCreate(scope *Scope) {
func afterCreateCallback(scope *Scope) {
scope.CallMethodWithErrorCheck("AfterCreate")
scope.CallMethodWithErrorCheck("AfterSave")
}

func init() {
defaultCallback.Create().Register("gorm:begin_transaction", BeginTransaction)
defaultCallback.Create().Register("gorm:before_create", BeforeCreate)
defaultCallback.Create().Register("gorm:save_before_associations", SaveBeforeAssociations)
defaultCallback.Create().Register("gorm:update_time_stamp_when_create", UpdateTimeStampWhenCreate)
defaultCallback.Create().Register("gorm:create", Create)
defaultCallback.Create().Register("gorm:force_reload_after_create", ForceReloadAfterCreate)
defaultCallback.Create().Register("gorm:save_after_associations", SaveAfterAssociations)
defaultCallback.Create().Register("gorm:after_create", AfterCreate)
defaultCallback.Create().Register("gorm:commit_or_rollback_transaction", CommitOrRollbackTransaction)
defaultCallback.Create().Register("gorm:begin_transaction", beginTransactionCallback)
defaultCallback.Create().Register("gorm:before_create", beforeCreateCallback)
defaultCallback.Create().Register("gorm:save_before_associations", saveBeforeAssociationsCallback)
defaultCallback.Create().Register("gorm:update_time_stamp_when_create", updateTimeStampForCreateCallback)
defaultCallback.Create().Register("gorm:create", createCallback)
defaultCallback.Create().Register("gorm:force_reload_after_create", forceReloadAfterCreateCallback)
defaultCallback.Create().Register("gorm:save_after_associations", saveAfterAssociationsCallback)
defaultCallback.Create().Register("gorm:after_create", afterCreateCallback)
defaultCallback.Create().Register("gorm:commit_or_rollback_transaction", commitOrRollbackTransactionCallback)
}
16 changes: 8 additions & 8 deletions callback_delete.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,11 @@ package gorm

import "fmt"

func BeforeDelete(scope *Scope) {
func beforeDeleteCallback(scope *Scope) {
scope.CallMethodWithErrorCheck("BeforeDelete")
}

func Delete(scope *Scope) {
func deleteCallback(scope *Scope) {
if !scope.HasError() {
if !scope.Search.Unscoped && scope.HasColumn("DeletedAt") {
scope.Raw(
Expand All @@ -23,14 +23,14 @@ func Delete(scope *Scope) {
}
}

func AfterDelete(scope *Scope) {
func afterDeleteCallback(scope *Scope) {
scope.CallMethodWithErrorCheck("AfterDelete")
}

func init() {
defaultCallback.Delete().Register("gorm:begin_transaction", BeginTransaction)
defaultCallback.Delete().Register("gorm:before_delete", BeforeDelete)
defaultCallback.Delete().Register("gorm:delete", Delete)
defaultCallback.Delete().Register("gorm:after_delete", AfterDelete)
defaultCallback.Delete().Register("gorm:commit_or_rollback_transaction", CommitOrRollbackTransaction)
defaultCallback.Delete().Register("gorm:begin_transaction", beginTransactionCallback)
defaultCallback.Delete().Register("gorm:before_delete", beforeDeleteCallback)
defaultCallback.Delete().Register("gorm:delete", deleteCallback)
defaultCallback.Delete().Register("gorm:after_delete", afterDeleteCallback)
defaultCallback.Delete().Register("gorm:commit_or_rollback_transaction", commitOrRollbackTransactionCallback)
}
10 changes: 5 additions & 5 deletions callback_query.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ import (
"reflect"
)

func Query(scope *Scope) {
func queryCallback(scope *Scope) {
defer scope.trace(NowFunc())

var (
Expand Down Expand Up @@ -78,12 +78,12 @@ func Query(scope *Scope) {
}
}

func AfterQuery(scope *Scope) {
func afterQueryCallback(scope *Scope) {
scope.CallMethodWithErrorCheck("AfterFind")
}

func init() {
defaultCallback.Query().Register("gorm:query", Query)
defaultCallback.Query().Register("gorm:after_query", AfterQuery)
defaultCallback.Query().Register("gorm:preload", Preload)
defaultCallback.Query().Register("gorm:query", queryCallback)
defaultCallback.Query().Register("gorm:after_query", afterQueryCallback)
defaultCallback.Query().Register("gorm:preload", preloadCallback)
}
2 changes: 1 addition & 1 deletion preload.go → callback_query_preload.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ import (
"strings"
)

func Preload(scope *Scope) {
func preloadCallback(scope *Scope) {
if scope.Search.preload == nil || scope.HasError() {
return
}
Expand Down
8 changes: 4 additions & 4 deletions callback_shared.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,15 @@ package gorm

import "reflect"

func BeginTransaction(scope *Scope) {
func beginTransactionCallback(scope *Scope) {
scope.Begin()
}

func CommitOrRollbackTransaction(scope *Scope) {
func commitOrRollbackTransactionCallback(scope *Scope) {
scope.CommitOrRollback()
}

func SaveBeforeAssociations(scope *Scope) {
func saveBeforeAssociationsCallback(scope *Scope) {
if !scope.shouldSaveAssociations() {
return
}
Expand All @@ -32,7 +32,7 @@ func SaveBeforeAssociations(scope *Scope) {
}
}

func SaveAfterAssociations(scope *Scope) {
func saveAfterAssociationsCallback(scope *Scope) {
if !scope.shouldSaveAssociations() {
return
}
Expand Down
28 changes: 14 additions & 14 deletions callback_update.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ import (
"strings"
)

func AssignUpdateAttributes(scope *Scope) {
func assignUpdateAttributesCallback(scope *Scope) {
if attrs, ok := scope.InstanceGet("gorm:update_interface"); ok {
if maps := convertInterfaceToMap(attrs); len(maps) > 0 {
protected, ok := scope.Get("gorm:ignore_protected_attrs")
Expand All @@ -24,20 +24,20 @@ func AssignUpdateAttributes(scope *Scope) {
}
}

func BeforeUpdate(scope *Scope) {
func beforeUpdateCallback(scope *Scope) {
if _, ok := scope.Get("gorm:update_column"); !ok {
scope.CallMethodWithErrorCheck("BeforeSave")
scope.CallMethodWithErrorCheck("BeforeUpdate")
}
}

func UpdateTimeStampWhenUpdate(scope *Scope) {
func updateTimeStampForUpdateCallback(scope *Scope) {
if _, ok := scope.Get("gorm:update_column"); !ok {
scope.SetColumn("UpdatedAt", NowFunc())
}
}

func Update(scope *Scope) {
func updateCallback(scope *Scope) {
if !scope.HasError() {
var sqls []string

Expand Down Expand Up @@ -75,21 +75,21 @@ func Update(scope *Scope) {
}
}

func AfterUpdate(scope *Scope) {
func afterUpdateCallback(scope *Scope) {
if _, ok := scope.Get("gorm:update_column"); !ok {
scope.CallMethodWithErrorCheck("AfterUpdate")
scope.CallMethodWithErrorCheck("AfterSave")
}
}

func init() {
defaultCallback.Update().Register("gorm:assign_update_attributes", AssignUpdateAttributes)
defaultCallback.Update().Register("gorm:begin_transaction", BeginTransaction)
defaultCallback.Update().Register("gorm:before_update", BeforeUpdate)
defaultCallback.Update().Register("gorm:save_before_associations", SaveBeforeAssociations)
defaultCallback.Update().Register("gorm:update_time_stamp_when_update", UpdateTimeStampWhenUpdate)
defaultCallback.Update().Register("gorm:update", Update)
defaultCallback.Update().Register("gorm:save_after_associations", SaveAfterAssociations)
defaultCallback.Update().Register("gorm:after_update", AfterUpdate)
defaultCallback.Update().Register("gorm:commit_or_rollback_transaction", CommitOrRollbackTransaction)
defaultCallback.Update().Register("gorm:assign_update_attributes", assignUpdateAttributesCallback)
defaultCallback.Update().Register("gorm:begin_transaction", beginTransactionCallback)
defaultCallback.Update().Register("gorm:before_update", beforeUpdateCallback)
defaultCallback.Update().Register("gorm:save_before_associations", saveBeforeAssociationsCallback)
defaultCallback.Update().Register("gorm:update_time_stamp_when_update", updateTimeStampForUpdateCallback)
defaultCallback.Update().Register("gorm:update", updateCallback)
defaultCallback.Update().Register("gorm:save_after_associations", saveAfterAssociationsCallback)
defaultCallback.Update().Register("gorm:after_update", afterUpdateCallback)
defaultCallback.Update().Register("gorm:commit_or_rollback_transaction", commitOrRollbackTransactionCallback)
}

0 comments on commit de73d30

Please sign in to comment.