Skip to content

Commit

Permalink
Refactor Scope
Browse files Browse the repository at this point in the history
  • Loading branch information
jinzhu committed Sep 2, 2014
1 parent 9c7ff3d commit 953c347
Show file tree
Hide file tree
Showing 10 changed files with 87 additions and 84 deletions.
32 changes: 17 additions & 15 deletions association.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,27 +28,29 @@ func (association *Association) Find(value interface{}) *Association {

func (association *Association) Append(values ...interface{}) *Association {
scope := association.Scope
field := scope.IndirectValue().FieldByName(association.Column)
field := association.Field
fieldType := field.Field.Type()

for _, value := range values {
reflectvalue := reflect.ValueOf(value)
if reflectvalue.Kind() == reflect.Ptr {
if reflectvalue.Elem().Kind() == reflect.Struct {
if field.Type().Elem().Kind() == reflect.Ptr {
field.Set(reflect.Append(field, reflectvalue))
} else if field.Type().Elem().Kind() == reflect.Struct {
field.Set(reflect.Append(field, reflectvalue.Elem()))
if fieldType.Elem().Kind() == reflect.Ptr {
field.Set(reflect.Append(field.Field, reflectvalue))
} else if fieldType.Elem().Kind() == reflect.Struct {
field.Set(reflect.Append(field.Field, reflectvalue.Elem()))
}
} else if reflectvalue.Elem().Kind() == reflect.Slice {
if field.Type().Elem().Kind() == reflect.Ptr {
field.Set(reflect.AppendSlice(field, reflectvalue))
} else if field.Type().Elem().Kind() == reflect.Struct {
field.Set(reflect.AppendSlice(field, reflectvalue.Elem()))
if fieldType.Elem().Kind() == reflect.Ptr {
field.Set(reflect.AppendSlice(field.Field, reflectvalue))
} else if fieldType.Elem().Kind() == reflect.Struct {
field.Set(reflect.AppendSlice(field.Field, reflectvalue.Elem()))
}
}
} else if reflectvalue.Kind() == reflect.Struct && field.Type().Elem().Kind() == reflect.Struct {
field.Set(reflect.Append(field, reflectvalue))
} else if reflectvalue.Kind() == reflect.Slice && field.Type().Elem() == reflectvalue.Type().Elem() {
field.Set(reflect.AppendSlice(field, reflectvalue))
} else if reflectvalue.Kind() == reflect.Struct && fieldType.Elem().Kind() == reflect.Struct {
field.Set(reflect.Append(field.Field, reflectvalue))
} else if reflectvalue.Kind() == reflect.Slice && fieldType.Elem() == reflectvalue.Type().Elem() {
field.Set(reflect.AppendSlice(field.Field, reflectvalue))
} else {
association.err(errors.New("invalid association type"))
}
Expand Down Expand Up @@ -107,7 +109,7 @@ func (association *Association) Replace(values ...interface{}) *Association {
relationship := association.Field.Relationship
scope := association.Scope
if relationship.Kind == "many_to_many" {
field := scope.IndirectValue().FieldByName(association.Column)
field := association.Field.Field

oldPrimaryKeys := association.getPrimaryKeys(field.Interface())
association.Append(values...)
Expand Down Expand Up @@ -154,7 +156,7 @@ func (association *Association) Count() int {
count := -1
relationship := association.Field.Relationship
scope := association.Scope
field := scope.IndirectValue().FieldByName(association.Column)
field := association.Field.Field
fieldValue := field.Interface()
newScope := scope.New(fieldValue)

Expand Down
4 changes: 3 additions & 1 deletion association_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -158,10 +158,12 @@ func TestManyToMany(t *testing.T) {
languageA := Language{Name: "AA"}
DB.Save(&languageA)
DB.Model(&User{Id: user.Id}).Association("Languages").Append(languageA)

languageC := Language{Name: "CC"}
DB.Save(&languageC)
DB.Model(&user).Association("Languages").Append(&[]Language{{Name: "BB"}, languageC})
DB.Model(&User{Id: user.Id}).Association("Languages").Append([]Language{{Name: "DD"}, {Name: "EE"}})

DB.Model(&User{Id: user.Id}).Association("Languages").Append(&[]Language{{Name: "DD"}, {Name: "EE"}})

totalLanguages := []string{"ZH", "EN", "DE", "AA", "BB", "CC", "DD", "EE"}

Expand Down
2 changes: 1 addition & 1 deletion callback_create.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ func Create(scope *Scope) {
for _, field := range scope.Fields() {
if field.IsNormal && (!field.IsPrimaryKey || !scope.PrimaryKeyZero()) {
columns = append(columns, scope.Quote(field.DBName))
sqls = append(sqls, scope.AddToVars(field.Value))
sqls = append(sqls, scope.AddToVars(field.Field.Interface()))
}
}

Expand Down
26 changes: 12 additions & 14 deletions callback_shared.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,20 +19,18 @@ func SaveBeforeAssociations(scope *Scope) {
if !field.IsBlank && !field.IsIgnored {
relationship := field.Relationship
if relationship != nil && relationship.Kind == "belongs_to" {
value := reflect.ValueOf(field.Value)
value := field.Field
newDB := scope.NewDB()

if value.CanAddr() {
scope.Err(newDB.Save(value.Addr().Interface()).Error)
} else {
if !value.CanAddr() {
// If can't take address, then clone the value and set it back
value = reflect.New(reflect.ValueOf(field.Value).Type()).Elem()
for _, f := range newDB.NewScope(field.Value).Fields() {
value.FieldByName(f.Name).Set(reflect.ValueOf(f.Value))
value = reflect.New(value.Type()).Elem()
for _, f := range newDB.NewScope(field.Field.Interface()).Fields() {
value.FieldByName(f.Name).Set(reflect.ValueOf(f.Field.Interface()))
}
scope.Err(newDB.Save(value.Addr().Interface()).Error)
scope.SetColumn(field.Name, value.Interface())
}
scope.Err(newDB.Save(value.Addr().Interface()).Error)

if relationship.ForeignKey != "" {
scope.SetColumn(relationship.ForeignKey, newDB.NewScope(value.Interface()).PrimaryKeyValue())
Expand All @@ -48,7 +46,7 @@ func SaveAfterAssociations(scope *Scope) {
relationship := field.Relationship
if relationship != nil &&
(relationship.Kind == "has_one" || relationship.Kind == "has_many" || relationship.Kind == "many_to_many") {
value := reflect.ValueOf(field.Value)
value := field.Field

switch value.Kind() {
case reflect.Slice:
Expand Down Expand Up @@ -89,14 +87,14 @@ func SaveAfterAssociations(scope *Scope) {
newDB := scope.NewDB()
if value.CanAddr() {
if relationship.ForeignKey != "" {
newDB.NewScope(field.Value).SetColumn(relationship.ForeignKey, scope.PrimaryKeyValue())
newDB.NewScope(value.Addr().Interface()).SetColumn(relationship.ForeignKey, scope.PrimaryKeyValue())
}
scope.Err(newDB.Save(field.Value).Error)
scope.Err(newDB.Save(value.Addr().Interface()).Error)
} else {
destValue := reflect.New(reflect.TypeOf(field.Value)).Elem()
destValue := reflect.New(field.Field.Type()).Elem()

for _, f := range newDB.NewScope(field.Value).Fields() {
destValue.FieldByName(f.Name).Set(reflect.ValueOf(f.Value))
for _, f := range newDB.NewScope(field.Field.Interface()).Fields() {
destValue.FieldByName(f.Name).Set(f.Field)
}

elem := destValue.Addr().Interface()
Expand Down
2 changes: 1 addition & 1 deletion callback_update.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ func Update(scope *Scope) {
} else {
for _, field := range scope.Fields() {
if !field.IsPrimaryKey && field.IsNormal && !field.IsIgnored {
sqls = append(sqls, fmt.Sprintf("%v = %v", scope.Quote(field.DBName), scope.AddToVars(field.Value)))
sqls = append(sqls, fmt.Sprintf("%v = %v", scope.Quote(field.DBName), scope.AddToVars(field.Field.Interface())))
}
}
}
Expand Down
26 changes: 21 additions & 5 deletions field.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@ type Field struct {
Name string
DBName string
Field reflect.Value
Value interface{}
Tag reflect.StructTag
Relationship *relationship
IsNormal bool
Expand All @@ -26,12 +25,29 @@ type Field struct {
IsPrimaryKey bool
}

func (f *Field) IsScanner() bool {
_, isScanner := reflect.New(reflect.ValueOf(f.Value).Type()).Interface().(sql.Scanner)
func (field *Field) IsScanner() bool {
_, isScanner := reflect.New(field.Field.Type()).Interface().(sql.Scanner)
return isScanner
}

func (f *Field) IsTime() bool {
_, isTime := f.Value.(time.Time)
func (field *Field) IsTime() bool {
_, isTime := field.Field.Interface().(time.Time)
return isTime
}

func (field *Field) Set(value interface{}) (result bool) {
if field.Field.IsValid() && field.Field.CanAddr() {
result = true
if scanner, ok := field.Field.Addr().Interface().(sql.Scanner); ok {
scanner.Scan(value)
} else if reflect.TypeOf(value).ConvertibleTo(field.Field.Type()) {
field.Field.Set(reflect.ValueOf(value).Convert(field.Field.Type()))
} else {
result = false
}
}
if result {
field.IsBlank = isBlank(field.Field)
}
return
}
33 changes: 17 additions & 16 deletions scope.go
Original file line number Diff line number Diff line change
Expand Up @@ -150,13 +150,17 @@ func (scope *Scope) FieldValueByName(name string) (interface{}, bool) {
}

// SetColumn to set the column's value
func (scope *Scope) SetColumn(column string, value interface{}) bool {
if scope.Value == nil {
return false
}
for _, field := range scope.Fields() {
if field.Name == column || field.DBName == column {
return setFieldValue(field.Field, value)
func (scope *Scope) SetColumn(column interface{}, value interface{}) bool {
if field, ok := column.(*Field); ok {
return field.Set(value)
} else if str, ok := column.(string); ok {
if scope.Value == nil {
return false
}
for _, field := range scope.Fields() {
if field.Name == str || field.DBName == str {
return field.Set(value)
}
}
}
return false
Expand Down Expand Up @@ -267,11 +271,9 @@ func (scope *Scope) CombinedConditionSql() string {
}

func (scope *Scope) FieldByName(name string) (field *Field, ok bool) {
if scope.Value != nil {
if scope.IndirectValue().Kind() == reflect.Struct {
if f, ok := scope.IndirectValue().Type().FieldByName(SnakeToUpperCamel(name)); ok {
return scope.fieldFromStruct(f, true)[0], true
}
for _, field := range scope.Fields() {
if field.Name == name {
return field, true
}
}
return nil, false
Expand All @@ -285,7 +287,6 @@ func (scope *Scope) fieldFromStruct(fieldStruct reflect.StructField, withRelatio
value := scope.IndirectValue().FieldByName(fieldStruct.Name)
indirectValue := reflect.Indirect(value)
field.Field = value
field.Value = value.Interface()
field.IsBlank = isBlank(value)

// Search for primary key tag identifier
Expand Down Expand Up @@ -416,9 +417,9 @@ func (scope *Scope) Fields(noRelations ...bool) map[string]*Field {
}
}

// if withRelation {
// scope.fields = fields
// }
if withRelation {
scope.fields = fields
}

return fields
}
Expand Down
22 changes: 11 additions & 11 deletions scope_private.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ func (scope *Scope) buildWhereCondition(clause map[string]interface{}) (str stri
var sqls []string
for _, field := range scope.New(value).Fields() {
if !field.IsBlank {
sqls = append(sqls, fmt.Sprintf("(%v = %v)", scope.Quote(field.DBName), scope.AddToVars(field.Value)))
sqls = append(sqls, fmt.Sprintf("(%v = %v)", scope.Quote(field.DBName), scope.AddToVars(field.Field.Interface())))
}
}
return strings.Join(sqls, " AND ")
Expand Down Expand Up @@ -103,7 +103,7 @@ func (scope *Scope) buildNotCondition(clause map[string]interface{}) (str string
var sqls []string
for _, field := range scope.New(value).Fields() {
if !field.IsBlank {
sqls = append(sqls, fmt.Sprintf("(%v <> %v)", scope.Quote(field.DBName), scope.AddToVars(field.Value)))
sqls = append(sqls, fmt.Sprintf("(%v <> %v)", scope.Quote(field.DBName), scope.AddToVars(field.Field.Interface())))
}
}
return strings.Join(sqls, " AND ")
Expand Down Expand Up @@ -264,17 +264,17 @@ func (scope *Scope) updatedAttrsWithValues(values map[string]interface{}, ignore
}

for key, value := range values {
if field := data.FieldByName(SnakeToUpperCamel(key)); field.IsValid() {
if field, ok := scope.FieldByName(SnakeToUpperCamel(key)); ok && field.Field.IsValid() {
func() {
defer func() {
if err := recover(); err != nil {
hasUpdate = true
setFieldValue(field, value)
field.Set(value)
}
}()

if field.Interface() != value {
switch field.Kind() {
if field.Field.Interface() != value {
switch field.Field.Kind() {
case reflect.Int, reflect.Int32, reflect.Int64:
if s, ok := value.(string); ok {
i, err := strconv.Atoi(s)
Expand All @@ -283,13 +283,13 @@ func (scope *Scope) updatedAttrsWithValues(values map[string]interface{}, ignore
}
}

if field.Int() != reflect.ValueOf(value).Int() {
if field.Field.Int() != reflect.ValueOf(value).Int() {
hasUpdate = true
setFieldValue(field, value)
field.Set(value)
}
default:
hasUpdate = true
setFieldValue(field, value)
field.Set(value)
}
}
}()
Expand Down Expand Up @@ -324,8 +324,8 @@ func (scope *Scope) sqlTagForField(field *Field) (typ string) {
additionalType = additionalType + "DEFAULT " + value
}

value := field.Value
reflectValue := reflect.ValueOf(value)
value := field.Field.Interface()
reflectValue := field.Field

switch reflectValue.Kind() {
case reflect.Slice:
Expand Down
4 changes: 2 additions & 2 deletions utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ func ToSnake(u string) string {
}

s := strings.ToLower(buf.String())
go smap.Set(u, s)
smap.Set(u, s)
return s
}

Expand All @@ -86,7 +86,7 @@ func SnakeToUpperCamel(s string) string {
}

u := buf.String()
go umap.Set(s, u)
umap.Set(s, u)
return u
}

Expand Down
20 changes: 2 additions & 18 deletions utils_private.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
package gorm

import (
"database/sql"
"fmt"
"os"
"reflect"
Expand All @@ -11,7 +10,7 @@ import (
)

func fileWithLineNum() string {
for i := 1; i < 15; i++ {
for i := 2; i < 15; i++ {
_, file, line, ok := runtime.Caller(i)
if ok && (!regexp.MustCompile(`jinzhu/gorm/.*.go`).MatchString(file) || regexp.MustCompile(`jinzhu/gorm/.*test.go`).MatchString(file)) {
return fmt.Sprintf("%v:%v", strings.TrimPrefix(file, os.Getenv("GOPATH")+"src/"), line)
Expand All @@ -20,21 +19,6 @@ func fileWithLineNum() string {
return ""
}

func setFieldValue(field reflect.Value, value interface{}) (result bool) {
result = false
if field.IsValid() && field.CanAddr() {
result = true
if scanner, ok := field.Addr().Interface().(sql.Scanner); ok {
scanner.Scan(value)
} else if reflect.TypeOf(value).ConvertibleTo(field.Type()) {
field.Set(reflect.ValueOf(value).Convert(field.Type()))
} else {
result = false
}
}
return
}

func isBlank(value reflect.Value) bool {
return reflect.DeepEqual(value.Interface(), reflect.Zero(value.Type()).Interface())
}
Expand Down Expand Up @@ -82,7 +66,7 @@ func convertInterfaceToMap(values interface{}) map[string]interface{} {
scope := Scope{Value: values}
for _, field := range scope.Fields() {
if !field.IsBlank {
attrs[field.DBName] = field.Value
attrs[field.DBName] = field.Field.Interface()
}
}
}
Expand Down

0 comments on commit 953c347

Please sign in to comment.