Skip to content

Commit

Permalink
Refactor Scope updatedAttrsWithValues
Browse files Browse the repository at this point in the history
  • Loading branch information
jinzhu committed Mar 9, 2016
1 parent a0aa21a commit 8de97c2
Show file tree
Hide file tree
Showing 5 changed files with 50 additions and 73 deletions.
10 changes: 4 additions & 6 deletions callback_update.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,12 +21,10 @@ func init() {
// assignUpdatingAttributesCallback assign updating attributes to model
func assignUpdatingAttributesCallback(scope *Scope) {
if attrs, ok := scope.InstanceGet("gorm:update_interface"); ok {
if maps := convertInterfaceToMap(attrs); len(maps) > 0 {
if updateMaps, hasUpdate := scope.updatedAttrsWithValues(maps); hasUpdate {
scope.InstanceSet("gorm:update_attrs", updateMaps)
} else {
scope.SkipLeft()
}
if updateMaps, hasUpdate := scope.updatedAttrsWithValues(attrs); hasUpdate {
scope.InstanceSet("gorm:update_attrs", updateMaps)
} else {
scope.SkipLeft()
}
}
}
Expand Down
2 changes: 1 addition & 1 deletion main.go
Original file line number Diff line number Diff line change
Expand Up @@ -310,7 +310,7 @@ func (s *DB) FirstOrInit(out interface{}, where ...interface{}) *DB {
}
c.NewScope(out).inlineCondition(where...).initialize()
} else {
c.NewScope(out).updatedAttrsWithValues(convertInterfaceToMap(c.search.assignAttrs))
c.NewScope(out).updatedAttrsWithValues(c.search.assignAttrs)
}
return c
}
Expand Down
62 changes: 45 additions & 17 deletions scope.go
Original file line number Diff line number Diff line change
Expand Up @@ -793,27 +793,55 @@ func (scope *Scope) callCallbacks(funcs []*func(s *Scope)) *Scope {
return scope
}

func (scope *Scope) updatedAttrsWithValues(values map[string]interface{}) (results map[string]interface{}, hasUpdate bool) {
func convertInterfaceToMap(values interface{}) map[string]interface{} {
var attrs = map[string]interface{}{}

switch value := values.(type) {
case map[string]interface{}:
return value
case []interface{}:
for _, v := range value {
for key, value := range convertInterfaceToMap(v) {
attrs[key] = value
}
}
case interface{}:
reflectValue := reflect.ValueOf(values)

switch reflectValue.Kind() {
case reflect.Map:
for _, key := range reflectValue.MapKeys() {
attrs[ToDBName(key.Interface().(string))] = reflectValue.MapIndex(key).Interface()
}
default:
for _, field := range (&Scope{Value: values}).Fields() {
if !field.IsBlank {
attrs[field.DBName] = field.Field.Interface()
}
}
}
}
return attrs
}

func (scope *Scope) updatedAttrsWithValues(value interface{}) (results map[string]interface{}, hasUpdate bool) {
if scope.IndirectValue().Kind() != reflect.Struct {
return values, true
return convertInterfaceToMap(value), true
}

results = map[string]interface{}{}
for key, value := range values {

for key, value := range convertInterfaceToMap(value) {
if field, ok := scope.FieldByName(key); ok && scope.changeableField(field) {
if !reflect.DeepEqual(field.Field, reflect.ValueOf(value)) {
if _, ok := value.(*expr); ok {
hasUpdate = true
results[field.DBName] = value
} else if !equalAsString(field.Field.Interface(), value) {
field.Set(value)
if field.IsNormal {
hasUpdate = true
results[field.DBName] = field.Field.Interface()
}
}
if _, ok := value.(*expr); ok {
hasUpdate = true
results[field.DBName] = value
} else {
field.Set(value)
if field.IsNormal {
hasUpdate = true
results[field.DBName] = field.Field.Interface()
}
}
}
}
Expand All @@ -836,10 +864,10 @@ func (scope *Scope) rows() (*sql.Rows, error) {

func (scope *Scope) initialize() *Scope {
for _, clause := range scope.Search.whereConditions {
scope.updatedAttrsWithValues(convertInterfaceToMap(clause["query"]))
scope.updatedAttrsWithValues(clause["query"])
}
scope.updatedAttrsWithValues(convertInterfaceToMap(scope.Search.initAttrs))
scope.updatedAttrsWithValues(convertInterfaceToMap(scope.Search.assignAttrs))
scope.updatedAttrsWithValues(scope.Search.initAttrs)
scope.updatedAttrsWithValues(scope.Search.assignAttrs)
return scope
}

Expand Down
18 changes: 0 additions & 18 deletions update_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,13 +20,6 @@ func TestUpdate(t *testing.T) {
DB.First(&product1, product1.Id)
DB.First(&product2, product2.Id)
updatedAt1 := product1.UpdatedAt
updatedAt2 := product2.UpdatedAt

var product3 Product
DB.First(&product3, product2.Id).Update("code", "product2newcode")
if updatedAt2.Format(time.RFC3339Nano) != product3.UpdatedAt.Format(time.RFC3339Nano) {
t.Errorf("updatedAt should not be updated if nothing changed")
}

if DB.First(&Product{}, "code = ?", product1.Code).RecordNotFound() {
t.Errorf("Product1 should not be updated")
Expand Down Expand Up @@ -135,19 +128,8 @@ func TestUpdates(t *testing.T) {

DB.First(&product1, product1.Id)
DB.First(&product2, product2.Id)
updatedAt1 := product1.UpdatedAt
updatedAt2 := product2.UpdatedAt

var product3 Product
DB.First(&product3, product1.Id).Updates(Product{Code: "product1newcode", Price: 100})
if product3.Code != "product1newcode" || product3.Price != 100 {
t.Errorf("Record should be updated with struct")
}

if updatedAt1.Format(time.RFC3339Nano) != product3.UpdatedAt.Format(time.RFC3339Nano) {
t.Errorf("updatedAt should not be updated if nothing changed")
}

if DB.First(&Product{}, "code = ? and price = ?", product2.Code, product2.Price).RecordNotFound() {
t.Errorf("Product2 should not be updated")
}
Expand Down
31 changes: 0 additions & 31 deletions utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -199,37 +199,6 @@ func toSearchableMap(attrs ...interface{}) (result interface{}) {
return
}

func convertInterfaceToMap(values interface{}) map[string]interface{} {
attrs := map[string]interface{}{}

switch value := values.(type) {
case map[string]interface{}:
return value
case []interface{}:
for _, v := range value {
for key, value := range convertInterfaceToMap(v) {
attrs[key] = value
}
}
case interface{}:
reflectValue := reflect.ValueOf(values)

switch reflectValue.Kind() {
case reflect.Map:
for _, key := range reflectValue.MapKeys() {
attrs[ToDBName(key.Interface().(string))] = reflectValue.MapIndex(key).Interface()
}
default:
for _, field := range (&Scope{Value: values}).Fields() {
if !field.IsBlank {
attrs[field.DBName] = field.Field.Interface()
}
}
}
}
return attrs
}

func equalAsString(a interface{}, b interface{}) bool {
return toString(a) == toString(b)
}
Expand Down

0 comments on commit 8de97c2

Please sign in to comment.