Skip to content

Commit

Permalink
Fix update fields having default with empty value
Browse files Browse the repository at this point in the history
  • Loading branch information
jinzhu committed Sep 14, 2015
1 parent edc1f78 commit 2a46856
Show file tree
Hide file tree
Showing 3 changed files with 18 additions and 3 deletions.
9 changes: 9 additions & 0 deletions callback_create.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,8 @@ func Create(scope *Scope) {
if !field.IsBlank || !field.HasDefaultValue {
columns = append(columns, scope.Quote(field.DBName))
sqls = append(sqls, scope.AddToVars(field.Field.Interface()))
} else if field.HasDefaultValue {
scope.InstanceSet("gorm:force_reload_after_create", true)
}
}
} else if relationship := field.Relationship; relationship != nil && relationship.Kind == "belongs_to" {
Expand Down Expand Up @@ -95,6 +97,12 @@ func Create(scope *Scope) {
}
}

func ForceReloadAfterCreate(scope *Scope) {
if _, ok := scope.InstanceGet("gorm:force_reload_after_create"); ok {
scope.DB().New().First(scope.Value)
}
}

func AfterCreate(scope *Scope) {
scope.CallMethodWithErrorCheck("AfterCreate")
scope.CallMethodWithErrorCheck("AfterSave")
Expand All @@ -106,6 +114,7 @@ func init() {
DefaultCallback.Create().Register("gorm:save_before_associations", SaveBeforeAssociations)
DefaultCallback.Create().Register("gorm:update_time_stamp_when_create", UpdateTimeStampWhenCreate)
DefaultCallback.Create().Register("gorm:create", Create)
DefaultCallback.Create().Register("gorm:force_reload_after_create", ForceReloadAfterCreate)
DefaultCallback.Create().Register("gorm:save_after_associations", SaveAfterAssociations)
DefaultCallback.Create().Register("gorm:after_create", AfterCreate)
DefaultCallback.Create().Register("gorm:commit_or_rollback_transaction", CommitOrRollbackTransaction)
Expand Down
4 changes: 1 addition & 3 deletions callback_update.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,9 +51,7 @@ func Update(scope *Scope) {
fields := scope.Fields()
for _, field := range fields {
if scope.changeableField(field) && !field.IsPrimaryKey && field.IsNormal {
if !field.IsBlank || !field.HasDefaultValue {
sqls = append(sqls, fmt.Sprintf("%v = %v", scope.Quote(field.DBName), scope.AddToVars(field.Field.Interface())))
}
sqls = append(sqls, fmt.Sprintf("%v = %v", scope.Quote(field.DBName), scope.AddToVars(field.Field.Interface())))
} else if relationship := field.Relationship; relationship != nil && relationship.Kind == "belongs_to" {
for _, dbName := range relationship.ForeignDBNames {
if relationField := fields[dbName]; !scope.changeableField(relationField) && !relationField.IsBlank {
Expand Down
8 changes: 8 additions & 0 deletions update_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,14 @@ func TestUpdateWithNoStdPrimaryKeyAndDefaultValues(t *testing.T) {
if animal.Name != "amazing horse" {
t.Errorf("Update a filed with a default value should occur. But got %v\n", animal.Name)
}

// When changing a field with a default value with blank value
animal.Name = ""
DB.Save(&animal)
DB.First(&animal, animal.Counter)
if animal.Name != "" {
t.Errorf("Update a filed to blank with a default value should occur. But got %v\n", animal.Name)
}
}

func TestUpdates(t *testing.T) {
Expand Down

0 comments on commit 2a46856

Please sign in to comment.