Skip to content

Commit

Permalink
Set nopLogger to DefaultCallback for avoid nil pointer dereference (g…
Browse files Browse the repository at this point in the history
  • Loading branch information
zaneli authored and jinzhu committed Dec 5, 2019
1 parent 0aba7ff commit e8c07b5
Show file tree
Hide file tree
Showing 3 changed files with 36 additions and 7 deletions.
9 changes: 2 additions & 7 deletions callback.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ package gorm
import "fmt"

// DefaultCallback default callbacks defined by gorm
var DefaultCallback = &Callback{}
var DefaultCallback = &Callback{logger: nopLogger{}}

// Callback is a struct that contains all CRUD callbacks
// Field `creates` contains callbacks will be call when creating object
Expand Down Expand Up @@ -101,12 +101,7 @@ func (cp *CallbackProcessor) Register(callbackName string, callback func(scope *
}
}

if cp.logger != nil {
// note cp.logger will be nil during the default gorm callback registrations
// as they occur within init() blocks. However, any user-registered callbacks
// will happen after cp.logger exists (as the default logger or user-specified).
cp.logger.Print("info", fmt.Sprintf("[info] registering callback `%v` from %v", callbackName, fileWithLineNum()))
}
cp.logger.Print("info", fmt.Sprintf("[info] registering callback `%v` from %v", callbackName, fileWithLineNum()))
cp.name = callbackName
cp.processor = &callback
cp.parent.processors = append(cp.parent.processors, cp)
Expand Down
30 changes: 30 additions & 0 deletions callbacks_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -217,3 +217,33 @@ func TestGetCallback(t *testing.T) {
t.Errorf("`gorm:test_callback_value` should be `3, true` but `%v, %v`", v, ok)
}
}

func TestUseDefaultCallback(t *testing.T) {
createCallbackName := "gorm:test_use_default_callback_for_create"
gorm.DefaultCallback.Create().Register(createCallbackName, func(*gorm.Scope) {
// nop
})
if gorm.DefaultCallback.Create().Get(createCallbackName) == nil {
t.Errorf("`%s` expected non-nil, but got nil", createCallbackName)
}
gorm.DefaultCallback.Create().Remove(createCallbackName)
if gorm.DefaultCallback.Create().Get(createCallbackName) != nil {
t.Errorf("`%s` expected nil, but got non-nil", createCallbackName)
}

updateCallbackName := "gorm:test_use_default_callback_for_update"
scopeValueName := "gorm:test_use_default_callback_for_update_value"
gorm.DefaultCallback.Update().Register(updateCallbackName, func(scope *gorm.Scope) {
scope.Set(scopeValueName, 1)
})
gorm.DefaultCallback.Update().Replace(updateCallbackName, func(scope *gorm.Scope) {
scope.Set(scopeValueName, 2)
})

scope := DB.NewScope(nil)
callback := gorm.DefaultCallback.Update().Get(updateCallbackName)
callback(scope)
if v, ok := scope.Get(scopeValueName); !ok || v != 2 {
t.Errorf("`%s` should be `2, true` but `%v, %v`", scopeValueName, v, ok)
}
}
4 changes: 4 additions & 0 deletions logger.go
Original file line number Diff line number Diff line change
Expand Up @@ -135,3 +135,7 @@ type Logger struct {
func (logger Logger) Print(values ...interface{}) {
logger.Println(LogFormatter(values...)...)
}

type nopLogger struct{}

func (nopLogger) Print(values ...interface{}) {}

0 comments on commit e8c07b5

Please sign in to comment.