Skip to content

Commit

Permalink
Refactor Callback
Browse files Browse the repository at this point in the history
  • Loading branch information
jinzhu committed Jan 16, 2016
1 parent dc23ae6 commit f1237e4
Show file tree
Hide file tree
Showing 8 changed files with 176 additions and 141 deletions.
129 changes: 82 additions & 47 deletions callback.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,34 +4,45 @@ import (
"fmt"
)

type callback struct {
// defaultCallbacks hold default callbacks defined by gorm
var defaultCallbacks = &Callbacks{}

// Callbacks contains callbacks that used when CURD objects
// Field `creates` hold callbacks will be call when creating object
// Field `updates` hold callbacks will be call when updating object
// Field `deletes` hold callbacks will be call when deleting object
// Field `queries` hold callbacks will be call when querying object with query methods like Find, First, Related, Association...
// Field `rowQueries` hold callbacks will be call when querying object with Row, Rows...
// Field `processors` hold all callback processors, will be used to generate above callbacks in order
type Callbacks struct {
creates []*func(scope *Scope)
updates []*func(scope *Scope)
deletes []*func(scope *Scope)
queries []*func(scope *Scope)
rowQueries []*func(scope *Scope)
processors []*callbackProcessor
processors []*CallbackProcessor
}

type callbackProcessor struct {
name string
before string
after string
replace bool
remove bool
typ string
processor *func(scope *Scope)
callback *callback
// callbackProcessor contains all informations for a callback
type CallbackProcessor struct {
name string // current callback's name
before string // register current callback before a callback
after string // register current callback after a callback
replace bool // replace callbacks with same name
remove bool // delete callbacks with same name
kind string // callback type: create, update, delete, query, row_query
processor *func(scope *Scope) // callback handler
parent *Callbacks
}

func (c *callback) addProcessor(typ string) *callbackProcessor {
cp := &callbackProcessor{typ: typ, callback: c}
func (c *Callbacks) addProcessor(kind string) *CallbackProcessor {
cp := &CallbackProcessor{kind: kind, parent: c}
c.processors = append(c.processors, cp)
return cp
}

func (c *callback) clone() *callback {
return &callback{
func (c *Callbacks) clone() *Callbacks {
return &Callbacks{
creates: c.creates,
updates: c.updates,
deletes: c.deletes,
Expand All @@ -40,57 +51,81 @@ func (c *callback) clone() *callback {
}
}

func (c *callback) Create() *callbackProcessor {
// Create could be used to register callbacks for creating object
// db.Callback().Create().After("gorm:create").Register("plugin:run_after_create", func(*Scope) {
// // business logic
// ...
//
// // set error if some thing wrong happened, will rollback the creating
// scope.Err(errors.New("error"))
// })
func (c *Callbacks) Create() *CallbackProcessor {
return c.addProcessor("create")
}

func (c *callback) Update() *callbackProcessor {
// Update could be used to register callbacks for updating object, refer `Create` for usage
func (c *Callbacks) Update() *CallbackProcessor {
return c.addProcessor("update")
}

func (c *callback) Delete() *callbackProcessor {
// Delete could be used to register callbacks for deleting object, refer `Create` for usage
func (c *Callbacks) Delete() *CallbackProcessor {
return c.addProcessor("delete")
}

func (c *callback) Query() *callbackProcessor {
// Query could be used to register callbacks for querying objects with query methods like `Find`, `First`, `Related`, `Association`...
// refer `Create` for usage
func (c *Callbacks) Query() *CallbackProcessor {
return c.addProcessor("query")
}

func (c *callback) RowQuery() *callbackProcessor {
// Query could be used to register callbacks for querying objects with `Row`, `Rows`, refer `Create` for usage
func (c *Callbacks) RowQuery() *CallbackProcessor {
return c.addProcessor("row_query")
}

func (cp *callbackProcessor) Before(name string) *callbackProcessor {
cp.before = name
// After insert a new callback after callback `callbackName`, refer `Callbacks.Create`
func (cp *CallbackProcessor) After(callbackName string) *CallbackProcessor {
cp.after = callbackName
return cp
}

func (cp *callbackProcessor) After(name string) *callbackProcessor {
cp.after = name
// Before insert a new callback before callback `callbackName`, refer `Callbacks.Create`
func (cp *CallbackProcessor) Before(callbackName string) *CallbackProcessor {
cp.before = callbackName
return cp
}

func (cp *callbackProcessor) Register(name string, fc func(scope *Scope)) {
cp.name = name
cp.processor = &fc
cp.callback.sort()
// Register a new callback, refer `Callbacks.Create`
func (cp *CallbackProcessor) Register(callbackName string, callback func(scope *Scope)) {
cp.name = callbackName
cp.processor = &callback
cp.parent.reorder()
}

func (cp *callbackProcessor) Remove(name string) {
fmt.Printf("[info] removing callback `%v` from %v\n", name, fileWithLineNum())
cp.name = name
// Remove a registered callback
// db.Callback().Create().Remove("gorm:update_time_stamp_when_create")
func (cp *CallbackProcessor) Remove(callbackName string) {
fmt.Printf("[info] removing callback `%v` from %v\n", callbackName, fileWithLineNum())
cp.name = callbackName
cp.remove = true
cp.callback.sort()
}

func (cp *callbackProcessor) Replace(name string, fc func(scope *Scope)) {
fmt.Printf("[info] replacing callback `%v` from %v\n", name, fileWithLineNum())
cp.name = name
cp.processor = &fc
cp.parent.reorder()
}

// Replace a registered callback with new callback
// db.Callback().Create().Replace("gorm:update_time_stamp_when_create", func(*Scope) {
// scope.SetColumn("Created", now)
// scope.SetColumn("Updated", now)
// })
func (cp *CallbackProcessor) Replace(callbackName string, callback func(scope *Scope)) {
fmt.Printf("[info] replacing callback `%v` from %v\n", callbackName, fileWithLineNum())
cp.name = callbackName
cp.processor = &callback
cp.replace = true
cp.callback.sort()
cp.parent.reorder()
}

// getRIndex get right index from string slice
func getRIndex(strs []string, str string) int {
for i := len(strs) - 1; i >= 0; i-- {
if strs[i] == str {
Expand All @@ -100,8 +135,9 @@ func getRIndex(strs []string, str string) int {
return -1
}

func sortProcessors(cps []*callbackProcessor) []*func(scope *Scope) {
var sortCallbackProcessor func(c *callbackProcessor)
// sortProcessors sort callback processors based on its before, after, remove, replace
func sortProcessors(cps []*CallbackProcessor) []*func(scope *Scope) {
var sortCallbackProcessor func(c *CallbackProcessor)
var names, sortedNames = []string{}, []string{}

for _, cp := range cps {
Expand All @@ -113,7 +149,7 @@ func sortProcessors(cps []*callbackProcessor) []*func(scope *Scope) {
names = append(names, cp.name)
}

sortCallbackProcessor = func(c *callbackProcessor) {
sortCallbackProcessor = func(c *CallbackProcessor) {
if getRIndex(sortedNames, c.name) > -1 {
return
}
Expand Down Expand Up @@ -172,11 +208,12 @@ func sortProcessors(cps []*callbackProcessor) []*func(scope *Scope) {
return append(sortedFuncs, funcs...)
}

func (c *callback) sort() {
var creates, updates, deletes, queries, rowQueries []*callbackProcessor
// reorder all registered processors, and reset CURD callbacks
func (c *Callbacks) reorder() {
var creates, updates, deletes, queries, rowQueries []*CallbackProcessor

for _, processor := range c.processors {
switch processor.typ {
switch processor.kind {
case "create":
creates = append(creates, processor)
case "update":
Expand All @@ -196,5 +233,3 @@ func (c *callback) sort() {
c.queries = sortProcessors(queries)
c.rowQueries = sortProcessors(rowQueries)
}

var DefaultCallback = &callback{processors: []*callbackProcessor{}}
18 changes: 9 additions & 9 deletions callback_create.go
Original file line number Diff line number Diff line change
Expand Up @@ -114,13 +114,13 @@ func AfterCreate(scope *Scope) {
}

func init() {
DefaultCallback.Create().Register("gorm:begin_transaction", BeginTransaction)
DefaultCallback.Create().Register("gorm:before_create", BeforeCreate)
DefaultCallback.Create().Register("gorm:save_before_associations", SaveBeforeAssociations)
DefaultCallback.Create().Register("gorm:update_time_stamp_when_create", UpdateTimeStampWhenCreate)
DefaultCallback.Create().Register("gorm:create", Create)
DefaultCallback.Create().Register("gorm:force_reload_after_create", ForceReloadAfterCreate)
DefaultCallback.Create().Register("gorm:save_after_associations", SaveAfterAssociations)
DefaultCallback.Create().Register("gorm:after_create", AfterCreate)
DefaultCallback.Create().Register("gorm:commit_or_rollback_transaction", CommitOrRollbackTransaction)
defaultCallbacks.Create().Register("gorm:begin_transaction", BeginTransaction)
defaultCallbacks.Create().Register("gorm:before_create", BeforeCreate)
defaultCallbacks.Create().Register("gorm:save_before_associations", SaveBeforeAssociations)
defaultCallbacks.Create().Register("gorm:update_time_stamp_when_create", UpdateTimeStampWhenCreate)
defaultCallbacks.Create().Register("gorm:create", Create)
defaultCallbacks.Create().Register("gorm:force_reload_after_create", ForceReloadAfterCreate)
defaultCallbacks.Create().Register("gorm:save_after_associations", SaveAfterAssociations)
defaultCallbacks.Create().Register("gorm:after_create", AfterCreate)
defaultCallbacks.Create().Register("gorm:commit_or_rollback_transaction", CommitOrRollbackTransaction)
}
10 changes: 5 additions & 5 deletions callback_delete.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,9 @@ func AfterDelete(scope *Scope) {
}

func init() {
DefaultCallback.Delete().Register("gorm:begin_transaction", BeginTransaction)
DefaultCallback.Delete().Register("gorm:before_delete", BeforeDelete)
DefaultCallback.Delete().Register("gorm:delete", Delete)
DefaultCallback.Delete().Register("gorm:after_delete", AfterDelete)
DefaultCallback.Delete().Register("gorm:commit_or_rollback_transaction", CommitOrRollbackTransaction)
defaultCallbacks.Delete().Register("gorm:begin_transaction", BeginTransaction)
defaultCallbacks.Delete().Register("gorm:before_delete", BeforeDelete)
defaultCallbacks.Delete().Register("gorm:delete", Delete)
defaultCallbacks.Delete().Register("gorm:after_delete", AfterDelete)
defaultCallbacks.Delete().Register("gorm:commit_or_rollback_transaction", CommitOrRollbackTransaction)
}
6 changes: 3 additions & 3 deletions callback_query.go
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ func AfterQuery(scope *Scope) {
}

func init() {
DefaultCallback.Query().Register("gorm:query", Query)
DefaultCallback.Query().Register("gorm:after_query", AfterQuery)
DefaultCallback.Query().Register("gorm:preload", Preload)
defaultCallbacks.Query().Register("gorm:query", Query)
defaultCallbacks.Query().Register("gorm:after_query", AfterQuery)
defaultCallbacks.Query().Register("gorm:preload", Preload)
}
88 changes: 44 additions & 44 deletions callback_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,90 +23,90 @@ func afterCreate1(s *Scope) {}
func afterCreate2(s *Scope) {}

func TestRegisterCallback(t *testing.T) {
var callback = &callback{processors: []*callbackProcessor{}}
var callbacks = &Callbacks{}

callback.Create().Register("before_create1", beforeCreate1)
callback.Create().Register("before_create2", beforeCreate2)
callback.Create().Register("create", create)
callback.Create().Register("after_create1", afterCreate1)
callback.Create().Register("after_create2", afterCreate2)
callbacks.Create().Register("before_create1", beforeCreate1)
callbacks.Create().Register("before_create2", beforeCreate2)
callbacks.Create().Register("create", create)
callbacks.Create().Register("after_create1", afterCreate1)
callbacks.Create().Register("after_create2", afterCreate2)

if !equalFuncs(callback.creates, []string{"beforeCreate1", "beforeCreate2", "create", "afterCreate1", "afterCreate2"}) {
if !equalFuncs(callbacks.creates, []string{"beforeCreate1", "beforeCreate2", "create", "afterCreate1", "afterCreate2"}) {
t.Errorf("register callback")
}
}

func TestRegisterCallbackWithOrder(t *testing.T) {
var callback1 = &callback{processors: []*callbackProcessor{}}
callback1.Create().Register("before_create1", beforeCreate1)
callback1.Create().Register("create", create)
callback1.Create().Register("after_create1", afterCreate1)
callback1.Create().Before("after_create1").Register("after_create2", afterCreate2)
if !equalFuncs(callback1.creates, []string{"beforeCreate1", "create", "afterCreate2", "afterCreate1"}) {
var callbacks1 = &Callbacks{}
callbacks1.Create().Register("before_create1", beforeCreate1)
callbacks1.Create().Register("create", create)
callbacks1.Create().Register("after_create1", afterCreate1)
callbacks1.Create().Before("after_create1").Register("after_create2", afterCreate2)
if !equalFuncs(callbacks1.creates, []string{"beforeCreate1", "create", "afterCreate2", "afterCreate1"}) {
t.Errorf("register callback with order")
}

var callback2 = &callback{processors: []*callbackProcessor{}}
var callbacks2 = &Callbacks{}

callback2.Update().Register("create", create)
callback2.Update().Before("create").Register("before_create1", beforeCreate1)
callback2.Update().After("after_create2").Register("after_create1", afterCreate1)
callback2.Update().Before("before_create1").Register("before_create2", beforeCreate2)
callback2.Update().Register("after_create2", afterCreate2)
callbacks2.Update().Register("create", create)
callbacks2.Update().Before("create").Register("before_create1", beforeCreate1)
callbacks2.Update().After("after_create2").Register("after_create1", afterCreate1)
callbacks2.Update().Before("before_create1").Register("before_create2", beforeCreate2)
callbacks2.Update().Register("after_create2", afterCreate2)

if !equalFuncs(callback2.updates, []string{"beforeCreate2", "beforeCreate1", "create", "afterCreate2", "afterCreate1"}) {
if !equalFuncs(callbacks2.updates, []string{"beforeCreate2", "beforeCreate1", "create", "afterCreate2", "afterCreate1"}) {
t.Errorf("register callback with order")
}
}

func TestRegisterCallbackWithComplexOrder(t *testing.T) {
var callback1 = &callback{processors: []*callbackProcessor{}}
var callbacks1 = &Callbacks{}

callback1.Query().Before("after_create1").After("before_create1").Register("create", create)
callback1.Query().Register("before_create1", beforeCreate1)
callback1.Query().Register("after_create1", afterCreate1)
callbacks1.Query().Before("after_create1").After("before_create1").Register("create", create)
callbacks1.Query().Register("before_create1", beforeCreate1)
callbacks1.Query().Register("after_create1", afterCreate1)

if !equalFuncs(callback1.queries, []string{"beforeCreate1", "create", "afterCreate1"}) {
if !equalFuncs(callbacks1.queries, []string{"beforeCreate1", "create", "afterCreate1"}) {
t.Errorf("register callback with order")
}

var callback2 = &callback{processors: []*callbackProcessor{}}
var callbacks2 = &Callbacks{}

callback2.Delete().Before("after_create1").After("before_create1").Register("create", create)
callback2.Delete().Before("create").Register("before_create1", beforeCreate1)
callback2.Delete().After("before_create1").Register("before_create2", beforeCreate2)
callback2.Delete().Register("after_create1", afterCreate1)
callback2.Delete().After("after_create1").Register("after_create2", afterCreate2)
callbacks2.Delete().Before("after_create1").After("before_create1").Register("create", create)
callbacks2.Delete().Before("create").Register("before_create1", beforeCreate1)
callbacks2.Delete().After("before_create1").Register("before_create2", beforeCreate2)
callbacks2.Delete().Register("after_create1", afterCreate1)
callbacks2.Delete().After("after_create1").Register("after_create2", afterCreate2)

if !equalFuncs(callback2.deletes, []string{"beforeCreate1", "beforeCreate2", "create", "afterCreate1", "afterCreate2"}) {
if !equalFuncs(callbacks2.deletes, []string{"beforeCreate1", "beforeCreate2", "create", "afterCreate1", "afterCreate2"}) {
t.Errorf("register callback with order")
}
}

func replaceCreate(s *Scope) {}

func TestReplaceCallback(t *testing.T) {
var callback = &callback{processors: []*callbackProcessor{}}
var callbacks = &Callbacks{}

callback.Create().Before("after_create1").After("before_create1").Register("create", create)
callback.Create().Register("before_create1", beforeCreate1)
callback.Create().Register("after_create1", afterCreate1)
callback.Create().Replace("create", replaceCreate)
callbacks.Create().Before("after_create1").After("before_create1").Register("create", create)
callbacks.Create().Register("before_create1", beforeCreate1)
callbacks.Create().Register("after_create1", afterCreate1)
callbacks.Create().Replace("create", replaceCreate)

if !equalFuncs(callback.creates, []string{"beforeCreate1", "replaceCreate", "afterCreate1"}) {
if !equalFuncs(callbacks.creates, []string{"beforeCreate1", "replaceCreate", "afterCreate1"}) {
t.Errorf("replace callback")
}
}

func TestRemoveCallback(t *testing.T) {
var callback = &callback{processors: []*callbackProcessor{}}
var callbacks = &Callbacks{}

callback.Create().Before("after_create1").After("before_create1").Register("create", create)
callback.Create().Register("before_create1", beforeCreate1)
callback.Create().Register("after_create1", afterCreate1)
callback.Create().Remove("create")
callbacks.Create().Before("after_create1").After("before_create1").Register("create", create)
callbacks.Create().Register("before_create1", beforeCreate1)
callbacks.Create().Register("after_create1", afterCreate1)
callbacks.Create().Remove("create")

if !equalFuncs(callback.creates, []string{"beforeCreate1", "afterCreate1"}) {
if !equalFuncs(callbacks.creates, []string{"beforeCreate1", "afterCreate1"}) {
t.Errorf("remove callback")
}
}
Loading

0 comments on commit f1237e4

Please sign in to comment.