diff --git a/association.go b/association.go index 911626ab7..4f94d15f6 100644 --- a/association.go +++ b/association.go @@ -7,12 +7,11 @@ import ( ) type Association struct { - Scope *Scope - PrimaryKey interface{} - PrimaryType interface{} - Column string - Error error - Field *Field + Scope *Scope + PrimaryKey interface{} + Column string + Error error + Field *Field } func (association *Association) setErr(err error) *Association { @@ -158,11 +157,11 @@ func (association *Association) Count() int { whereSql := fmt.Sprintf("%v.%v = ?", newScope.QuotedTableName(), newScope.Quote(relationship.ForeignDBName)) countScope := scope.db.Model("").Table(newScope.QuotedTableName()).Where(whereSql, association.PrimaryKey) if relationship.ForeignType != "" { - countScope = countScope.Where(fmt.Sprintf("%v.%v = ?", newScope.QuotedTableName(), newScope.Quote(ToSnake(relationship.ForeignType))), association.PrimaryType) + countScope = countScope.Where(fmt.Sprintf("%v.%v = ?", newScope.QuotedTableName(), newScope.Quote(ToSnake(relationship.ForeignType))), scope.TableName()) } countScope.Count(&count) } else if relationship.Kind == "belongs_to" { - if v, err := scope.FieldValueByName(association.Column); err == nil { + if v, ok := scope.FieldByName(association.Column); ok { whereSql := fmt.Sprintf("%v.%v = ?", newScope.QuotedTableName(), newScope.Quote(relationship.ForeignDBName)) scope.db.Model("").Table(newScope.QuotedTableName()).Where(whereSql, v).Count(&count) } diff --git a/callback_create.go b/callback_create.go index fe0ba28d4..1bf44a3db 100644 --- a/callback_create.go +++ b/callback_create.go @@ -26,11 +26,10 @@ func Create(scope *Scope) { var sqls, columns []string for _, field := range scope.Fields() { if (field.IsNormal && !field.IsPrimaryKey) || (field.IsPrimaryKey && !field.IsBlank) { - if field.DefaultValue != nil && field.IsBlank { - continue + if !field.IsBlank || field.DefaultValue == nil { + columns = append(columns, scope.Quote(field.DBName)) + sqls = append(sqls, scope.AddToVars(field.Field.Interface())) } - columns = append(columns, scope.Quote(field.DBName)) - sqls = append(sqls, scope.AddToVars(field.Field.Interface())) } } diff --git a/callback_update.go b/callback_update.go index 0e32352fa..fc8ee3ad2 100644 --- a/callback_update.go +++ b/callback_update.go @@ -41,18 +41,16 @@ func Update(scope *Scope) { if !scope.HasError() { var sqls []string - updateAttrs, ok := scope.InstanceGet("gorm:update_attrs") - if ok { + if updateAttrs, ok := scope.InstanceGet("gorm:update_attrs"); ok { for key, value := range updateAttrs.(map[string]interface{}) { sqls = append(sqls, fmt.Sprintf("%v = %v", scope.Quote(key), scope.AddToVars(value))) } } else { for _, field := range scope.Fields() { - if !field.IsPrimaryKey && field.IsNormal && !field.IsIgnored { - if field.DefaultValue != nil && field.IsBlank { - continue + if !field.IsPrimaryKey && field.IsNormal { + if !field.IsBlank || field.DefaultValue == nil { + sqls = append(sqls, fmt.Sprintf("%v = %v", scope.Quote(field.DBName), scope.AddToVars(field.Field.Interface()))) } - sqls = append(sqls, fmt.Sprintf("%v = %v", scope.Quote(field.DBName), scope.AddToVars(field.Field.Interface()))) } } } @@ -68,8 +66,7 @@ func Update(scope *Scope) { } func AfterUpdate(scope *Scope) { - _, ok := scope.Get("gorm:update_column") - if !ok { + if _, ok := scope.Get("gorm:update_column"); !ok { scope.CallMethod("AfterUpdate") scope.CallMethod("AfterSave") } diff --git a/common_dialect.go b/common_dialect.go index 37de7399e..7013df06b 100644 --- a/common_dialect.go +++ b/common_dialect.go @@ -4,6 +4,7 @@ import ( "fmt" "reflect" "strings" + "time" ) type commonDialect struct{} @@ -36,7 +37,7 @@ func (s *commonDialect) SqlTag(value reflect.Value, size int) string { } return "VARCHAR(65532)" case reflect.Struct: - if value.Type() == timeType { + if _, ok := value.Interface().(time.Time); ok { return "TIMESTAMP" } default: diff --git a/dialect.go b/dialect.go index 42b5c742e..6e76b437a 100644 --- a/dialect.go +++ b/dialect.go @@ -3,11 +3,8 @@ package gorm import ( "fmt" "reflect" - "time" ) -var timeType = reflect.TypeOf(time.Time{}) - type Dialect interface { BinVar(i int) string SupportLastInsertId() bool diff --git a/field.go b/field.go index 66be9c314..7e3f1400d 100644 --- a/field.go +++ b/field.go @@ -12,7 +12,7 @@ type Field struct { Field reflect.Value } -func (field *Field) Set(value interface{}) (err error) { +func (field *Field) Set(value interface{}) error { if !field.Field.IsValid() { return errors.New("field value not valid") } @@ -26,16 +26,26 @@ func (field *Field) Set(value interface{}) (err error) { } if scanner, ok := field.Field.Addr().Interface().(sql.Scanner); ok { - scanner.Scan(value) - } else if reflect.TypeOf(value).ConvertibleTo(field.Field.Type()) { - field.Field.Set(reflect.ValueOf(value).Convert(field.Field.Type())) + if v, ok := value.(reflect.Value); ok { + scanner.Scan(v.Interface()) + } else { + scanner.Scan(value) + } } else { - return errors.New("could not convert argument") + reflectValue, ok := value.(reflect.Value) + if !ok { + reflectValue = reflect.ValueOf(value) + } + + if reflectValue.Type().ConvertibleTo(field.Field.Type()) { + field.Field.Set(reflectValue.Convert(field.Field.Type())) + } else { + return errors.New("could not convert argument") + } } field.IsBlank = isBlank(field.Field) - - return + return nil } // Fields get value's fields @@ -44,8 +54,14 @@ func (scope *Scope) Fields() map[string]*Field { fields := map[string]*Field{} structFields := scope.GetStructFields() + indirectValue := scope.IndirectValue() + isStruct := indirectValue.Kind() == reflect.Struct for _, structField := range structFields { - fields[structField.DBName] = scope.getField(structField) + if isStruct { + fields[structField.DBName] = getField(indirectValue, structField) + } else { + fields[structField.DBName] = &Field{StructField: structField, IsBlank: true} + } } scope.fields = fields @@ -53,15 +69,12 @@ func (scope *Scope) Fields() map[string]*Field { return scope.fields } -func (scope *Scope) getField(structField *StructField) *Field { - field := Field{StructField: structField} - indirectValue := scope.IndirectValue() - if len(structField.Names) > 0 && indirectValue.Kind() == reflect.Struct { - for _, name := range structField.Names { - indirectValue = reflect.Indirect(indirectValue).FieldByName(name) - } - field.Field = indirectValue +func getField(indirectValue reflect.Value, structField *StructField) *Field { + field := &Field{StructField: structField} + for _, name := range structField.Names { + indirectValue = reflect.Indirect(indirectValue).FieldByName(name) } + field.Field = indirectValue field.IsBlank = isBlank(indirectValue) - return &field + return field } diff --git a/logger.go b/logger.go index 001b2165c..48569561a 100644 --- a/logger.go +++ b/logger.go @@ -21,21 +21,21 @@ var defaultLogger = Logger{log.New(os.Stdout, "\r\n", 0)} // Format log var sqlRegexp = regexp.MustCompile(`(\$\d+)|\?`) -func (logger Logger) Print(v ...interface{}) { - if len(v) > 1 { - level := v[0] +func (logger Logger) Print(values ...interface{}) { + if len(values) > 1 { + level := values[0] currentTime := "\n\033[33m[" + NowFunc().Format("2006-01-02 15:04:05") + "]\033[0m" - source := fmt.Sprintf("\033[35m(%v)\033[0m", v[1]) + source := fmt.Sprintf("\033[35m(%v)\033[0m", values[1]) messages := []interface{}{source, currentTime} if level == "sql" { // duration - messages = append(messages, fmt.Sprintf(" \033[36;1m[%.2fms]\033[0m ", float64(v[2].(time.Duration).Nanoseconds()/1e4)/100.0)) + messages = append(messages, fmt.Sprintf(" \033[36;1m[%.2fms]\033[0m ", float64(values[2].(time.Duration).Nanoseconds()/1e4)/100.0)) // sql - messages = append(messages, fmt.Sprintf(sqlRegexp.ReplaceAllString(v[3].(string), "'%v'"), v[4].([]interface{})...)) + messages = append(messages, fmt.Sprintf(sqlRegexp.ReplaceAllString(values[3].(string), "'%v'"), values[4].([]interface{})...)) } else { messages = append(messages, "\033[31;1m") - messages = append(messages, v[2:]...) + messages = append(messages, values[2:]...) messages = append(messages, "\033[0m") } logger.Println(messages...) diff --git a/main.go b/main.go index 77cd112c6..7e7097262 100644 --- a/main.go +++ b/main.go @@ -30,7 +30,6 @@ type DB struct { logMode int logger logger dialect Dialect - tagIdentifier string singularTable bool source string values map[string]interface{} @@ -39,34 +38,39 @@ type DB struct { func Open(dialect string, args ...interface{}) (DB, error) { var db DB var err error - var source string - var dbSql sqlCommon if len(args) == 0 { err = errors.New("invalid database source") - } + } else { + var source string + var dbSql sqlCommon + + switch value := args[0].(type) { + case string: + var driver = dialect + if len(args) == 1 { + source = value + } else if len(args) >= 2 { + driver = value + source = args[1].(string) + } + dbSql, err = sql.Open(driver, source) + case sqlCommon: + source = reflect.Indirect(reflect.ValueOf(value)).FieldByName("dsn").String() + dbSql = value + } - switch value := args[0].(type) { - case string: - var driver = dialect - if len(args) == 1 { - source = value - } else if len(args) >= 2 { - driver = value - source = args[1].(string) + db = DB{ + dialect: NewDialect(dialect), + logger: defaultLogger, + callback: DefaultCallback, + source: source, + values: map[string]interface{}{}, + db: dbSql, } - dbSql, err = sql.Open(driver, source) - case sqlCommon: - source = reflect.Indirect(reflect.ValueOf(value)).FieldByName("dsn").String() - dbSql = value + db.parent = &db } - db = DB{dialect: NewDialect(dialect), tagIdentifier: "sql", - logger: defaultLogger, callback: DefaultCallback, source: source, - values: map[string]interface{}{}} - db.db = dbSql - db.parent = &db - return db, err } @@ -84,7 +88,7 @@ func (s *DB) New() *DB { return clone } -// Return the underlying sql.DB or sql.Tx instance. +// CommonDB Return the underlying sql.DB or sql.Tx instance. // Use of this method is discouraged. It's mainly intended to allow // coexistence with legacy non-GORM code. func (s *DB) CommonDB() sqlCommon { @@ -96,16 +100,12 @@ func (s *DB) Callback() *callback { return s.parent.callback } -func (s *DB) SetTagIdentifier(str string) { - s.parent.tagIdentifier = str -} - func (s *DB) SetLogger(l logger) { s.parent.logger = l } -func (s *DB) LogMode(b bool) *DB { - if b { +func (s *DB) LogMode(enable bool) *DB { + if enable { s.logMode = 2 } else { s.logMode = 1 @@ -113,8 +113,8 @@ func (s *DB) LogMode(b bool) *DB { return s } -func (s *DB) SingularTable(b bool) { - s.parent.singularTable = b +func (s *DB) SingularTable(enable bool) { + s.parent.singularTable = enable } func (s *DB) Where(query interface{}, args ...interface{}) *DB { @@ -158,11 +158,10 @@ func (s *DB) Joins(query string) *DB { } func (s *DB) Scopes(funcs ...func(*DB) *DB) *DB { - c := s for _, f := range funcs { - c = f(c) + s = f(s) } - return c + return s } func (s *DB) Unscoped() *DB { @@ -179,16 +178,14 @@ func (s *DB) Assign(attrs ...interface{}) *DB { func (s *DB) First(out interface{}, where ...interface{}) *DB { newScope := s.clone().NewScope(out) - newScope.Search = newScope.Search.clone() - newScope.Search.limit(1) + newScope.Search = newScope.Search.clone().limit(1) return newScope.InstanceSet("gorm:order_by_primary_key", "ASC"). inlineCondition(where...).callCallbacks(s.parent.callback.queries).db } func (s *DB) Last(out interface{}, where ...interface{}) *DB { newScope := s.clone().NewScope(out) - newScope.Search = newScope.Search.clone() - newScope.Search.limit(1) + newScope.Search = newScope.Search.clone().limit(1) return newScope.InstanceSet("gorm:order_by_primary_key", "DESC"). inlineCondition(where...).callCallbacks(s.parent.callback.queries).db } @@ -213,10 +210,9 @@ func (s *DB) Scan(dest interface{}) *DB { func (s *DB) FirstOrInit(out interface{}, where ...interface{}) *DB { c := s.clone() - r := c.First(out, where...) - if r.Error != nil { - if !r.RecordNotFound() { - return r + if result := c.First(out, where...); result.Error != nil { + if !result.RecordNotFound() { + return result } c.NewScope(out).inlineCondition(where...).initialize() } else { @@ -227,10 +223,9 @@ func (s *DB) FirstOrInit(out interface{}, where ...interface{}) *DB { func (s *DB) FirstOrCreate(out interface{}, where ...interface{}) *DB { c := s.clone() - r := c.First(out, where...) - if r.Error != nil { - if !r.RecordNotFound() { - return r + if result := c.First(out, where...); result.Error != nil { + if !result.RecordNotFound() { + return result } c.NewScope(out).inlineCondition(where...).initialize().callCallbacks(s.parent.callback.creates) } else if len(c.search.AssignAttrs) > 0 { @@ -418,25 +413,24 @@ func (s *DB) RemoveIndex(indexName string) *DB { } func (s *DB) Association(column string) *Association { + var err error scope := s.clone().NewScope(s.Value) - primaryKey := scope.PrimaryKeyValue() - primaryType := scope.TableName() - if reflect.DeepEqual(reflect.ValueOf(primaryKey), reflect.Zero(reflect.ValueOf(primaryKey).Type())) { - scope.Err(errors.New("primary key can't be nil")) - } - - var field *Field - var ok bool - if field, ok = scope.FieldByName(column); ok { - if field.Relationship == nil || field.Relationship.ForeignFieldName == "" { - scope.Err(fmt.Errorf("invalid association %v for %v", column, scope.IndirectValue().Type())) - } + if primaryField := scope.PrimaryKeyField(); primaryField.IsBlank { + err = errors.New("primary key can't be nil") } else { - scope.Err(fmt.Errorf("%v doesn't have column %v", scope.IndirectValue().Type(), column)) + if field, ok := scope.FieldByName(column); ok { + if field.Relationship == nil || field.Relationship.ForeignFieldName == "" { + err = fmt.Errorf("invalid association %v for %v", column, scope.IndirectValue().Type()) + } else { + return &Association{Scope: scope, Column: column, PrimaryKey: primaryField.Field.Interface(), Field: field} + } + } else { + err = fmt.Errorf("%v doesn't have column %v", scope.IndirectValue().Type(), column) + } } - return &Association{Scope: scope, Column: column, Error: s.Error, PrimaryKey: primaryKey, PrimaryType: primaryType, Field: field} + return &Association{Error: err} } func (s *DB) Preload(column string, conditions ...interface{}) *DB { diff --git a/model_struct.go b/model_struct.go index d7707f19c..77f3ec397 100644 --- a/model_struct.go +++ b/model_struct.go @@ -133,137 +133,134 @@ func (scope *Scope) GetModelStruct() *ModelStruct { // Set fields for i := 0; i < scopeType.NumField(); i++ { - fieldStruct := scopeType.Field(i) - if !ast.IsExported(fieldStruct.Name) { - continue - } - - field := &StructField{ - Struct: fieldStruct, - Name: fieldStruct.Name, - Names: []string{fieldStruct.Name}, - Tag: fieldStruct.Tag, - } - - if fieldStruct.Tag.Get("sql") == "-" { - field.IsIgnored = true - } else { - sqlSettings := parseTagSetting(field.Tag.Get("sql")) - gormSettings := parseTagSetting(field.Tag.Get("gorm")) - if _, ok := gormSettings["PRIMARY_KEY"]; ok { - field.IsPrimaryKey = true - modelStruct.PrimaryKeyField = field - } - - if value, ok := sqlSettings["DEFAULT"]; ok { - field.DefaultValue = &value + if fieldStruct := scopeType.Field(i); ast.IsExported(fieldStruct.Name) { + field := &StructField{ + Struct: fieldStruct, + Name: fieldStruct.Name, + Names: []string{fieldStruct.Name}, + Tag: fieldStruct.Tag, } - if value, ok := gormSettings["COLUMN"]; ok { - field.DBName = value + if fieldStruct.Tag.Get("sql") == "-" { + field.IsIgnored = true } else { - field.DBName = ToSnake(fieldStruct.Name) - } - - fieldType, indirectType := fieldStruct.Type, fieldStruct.Type - if indirectType.Kind() == reflect.Ptr { - indirectType = indirectType.Elem() - } - - if _, isScanner := reflect.New(fieldType).Interface().(sql.Scanner); isScanner { - field.IsScanner, field.IsNormal = true, true - } - - if _, isTime := reflect.New(indirectType).Interface().(*time.Time); isTime { - field.IsTime, field.IsNormal = true, true - } + sqlSettings := parseTagSetting(field.Tag.Get("sql")) + gormSettings := parseTagSetting(field.Tag.Get("gorm")) + if _, ok := gormSettings["PRIMARY_KEY"]; ok { + field.IsPrimaryKey = true + modelStruct.PrimaryKeyField = field + } - many2many := gormSettings["MANY2MANY"] - foreignKey := SnakeToUpperCamel(gormSettings["FOREIGNKEY"]) - foreignType := SnakeToUpperCamel(gormSettings["FOREIGNTYPE"]) - associationForeignKey := SnakeToUpperCamel(gormSettings["ASSOCIATIONFOREIGNKEY"]) - if polymorphic := SnakeToUpperCamel(gormSettings["POLYMORPHIC"]); polymorphic != "" { - foreignKey = polymorphic + "Id" - foreignType = polymorphic + "Type" - } + if value, ok := sqlSettings["DEFAULT"]; ok { + field.DefaultValue = &value + } - if !field.IsNormal { - switch indirectType.Kind() { - case reflect.Slice: - typ := indirectType.Elem() - if typ.Kind() == reflect.Ptr { - typ = typ.Elem() - } + if value, ok := gormSettings["COLUMN"]; ok { + field.DBName = value + } else { + field.DBName = ToSnake(fieldStruct.Name) + } - if typ.Kind() == reflect.Struct { - kind := "has_many" + fieldType, indirectType := fieldStruct.Type, fieldStruct.Type + if indirectType.Kind() == reflect.Ptr { + indirectType = indirectType.Elem() + } - if foreignKey == "" { - foreignKey = scopeType.Name() + "Id" - } + if _, isScanner := reflect.New(fieldType).Interface().(sql.Scanner); isScanner { + field.IsScanner, field.IsNormal = true, true + } - if associationForeignKey == "" { - associationForeignKey = typ.Name() + "Id" - } + if _, isTime := reflect.New(indirectType).Interface().(*time.Time); isTime { + field.IsTime, field.IsNormal = true, true + } - if many2many != "" { - kind = "many_to_many" - } else if !reflect.New(typ).Elem().FieldByName(foreignKey).IsValid() { - foreignKey = "" - } + many2many := gormSettings["MANY2MANY"] + foreignKey := SnakeToUpperCamel(gormSettings["FOREIGNKEY"]) + foreignType := SnakeToUpperCamel(gormSettings["FOREIGNTYPE"]) + associationForeignKey := SnakeToUpperCamel(gormSettings["ASSOCIATIONFOREIGNKEY"]) + if polymorphic := SnakeToUpperCamel(gormSettings["POLYMORPHIC"]); polymorphic != "" { + foreignKey = polymorphic + "Id" + foreignType = polymorphic + "Type" + } - field.Relationship = &Relationship{ - JoinTable: many2many, - ForeignType: foreignType, - ForeignFieldName: foreignKey, - AssociationForeignFieldName: associationForeignKey, - ForeignDBName: ToSnake(foreignKey), - AssociationForeignDBName: ToSnake(associationForeignKey), - Kind: kind, - } - } else { - field.IsNormal = true - } - case reflect.Struct: - if _, ok := gormSettings["EMBEDDED"]; ok || fieldStruct.Anonymous { - for _, field := range scope.New(reflect.New(indirectType).Interface()).GetStructFields() { - field.Names = append([]string{fieldStruct.Name}, field.Names...) - modelStruct.StructFields = append(modelStruct.StructFields, field) + if !field.IsNormal { + switch indirectType.Kind() { + case reflect.Slice: + typ := indirectType.Elem() + if typ.Kind() == reflect.Ptr { + typ = typ.Elem() } - break - } else { - var belongsToForeignKey, hasOneForeignKey, kind string - if foreignKey == "" { - belongsToForeignKey = field.Name + "Id" - hasOneForeignKey = scopeType.Name() + "Id" + if typ.Kind() == reflect.Struct { + kind := "has_many" + + if foreignKey == "" { + foreignKey = scopeType.Name() + "Id" + } + + if associationForeignKey == "" { + associationForeignKey = typ.Name() + "Id" + } + + if many2many != "" { + kind = "many_to_many" + } else if !reflect.New(typ).Elem().FieldByName(foreignKey).IsValid() { + foreignKey = "" + } + + field.Relationship = &Relationship{ + JoinTable: many2many, + ForeignType: foreignType, + ForeignFieldName: foreignKey, + AssociationForeignFieldName: associationForeignKey, + ForeignDBName: ToSnake(foreignKey), + AssociationForeignDBName: ToSnake(associationForeignKey), + Kind: kind, + } } else { - belongsToForeignKey = foreignKey - hasOneForeignKey = foreignKey + field.IsNormal = true } - - if _, ok := scopeType.FieldByName(belongsToForeignKey); ok { - kind = "belongs_to" - foreignKey = belongsToForeignKey + case reflect.Struct: + if _, ok := gormSettings["EMBEDDED"]; ok || fieldStruct.Anonymous { + for _, field := range scope.New(reflect.New(indirectType).Interface()).GetStructFields() { + field.Names = append([]string{fieldStruct.Name}, field.Names...) + modelStruct.StructFields = append(modelStruct.StructFields, field) + } + break } else { - foreignKey = hasOneForeignKey - kind = "has_one" + var belongsToForeignKey, hasOneForeignKey, kind string + + if foreignKey == "" { + belongsToForeignKey = field.Name + "Id" + hasOneForeignKey = scopeType.Name() + "Id" + } else { + belongsToForeignKey = foreignKey + hasOneForeignKey = foreignKey + } + + if _, ok := scopeType.FieldByName(belongsToForeignKey); ok { + kind = "belongs_to" + foreignKey = belongsToForeignKey + } else { + foreignKey = hasOneForeignKey + kind = "has_one" + } + + field.Relationship = &Relationship{ + ForeignFieldName: foreignKey, + ForeignDBName: ToSnake(foreignKey), + ForeignType: foreignType, + Kind: kind, + } } - field.Relationship = &Relationship{ - ForeignFieldName: foreignKey, - ForeignDBName: ToSnake(foreignKey), - ForeignType: foreignType, - Kind: kind, - } + default: + field.IsNormal = true } - - default: - field.IsNormal = true } } + modelStruct.StructFields = append(modelStruct.StructFields, field) } - modelStruct.StructFields = append(modelStruct.StructFields, field) } for _, field := range modelStruct.StructFields { diff --git a/mssql.go b/mssql.go index e1b157e02..1d1562eae 100644 --- a/mssql.go +++ b/mssql.go @@ -4,6 +4,7 @@ import ( "fmt" "reflect" "strings" + "time" ) type mssql struct{} @@ -36,7 +37,7 @@ func (s *mssql) SqlTag(value reflect.Value, size int) string { } return "text" case reflect.Struct: - if value.Type() == timeType { + if _, ok := value.Interface().(time.Time); ok { return "datetime2" } default: diff --git a/mysql.go b/mysql.go index 044c23064..7d9758a16 100644 --- a/mysql.go +++ b/mysql.go @@ -4,6 +4,7 @@ import ( "fmt" "reflect" "strings" + "time" ) type mysql struct{} @@ -36,7 +37,7 @@ func (s *mysql) SqlTag(value reflect.Value, size int) string { } return "longtext" case reflect.Struct: - if value.Type() == timeType { + if _, ok := value.Interface().(time.Time); ok { return "timestamp NULL" } default: diff --git a/postgres.go b/postgres.go index 8b8e6a934..3654ddd7f 100644 --- a/postgres.go +++ b/postgres.go @@ -5,6 +5,7 @@ import ( "database/sql/driver" "fmt" "reflect" + "time" "github.com/lib/pq/hstore" ) @@ -40,7 +41,7 @@ func (s *postgres) SqlTag(value reflect.Value, size int) string { } return "text" case reflect.Struct: - if value.Type() == timeType { + if _, ok := value.Interface().(time.Time); ok { return "timestamp with time zone" } case reflect.Map: diff --git a/preload.go b/preload.go index 7ac0db0d9..eea75259e 100644 --- a/preload.go +++ b/preload.go @@ -7,7 +7,7 @@ import ( "reflect" ) -func getFieldValue(value reflect.Value, field string) interface{} { +func getRealValue(value reflect.Value, field string) interface{} { result := reflect.Indirect(value).FieldByName(field).Interface() if r, ok := result.(driver.Valuer); ok { result, _ = r.Value() @@ -20,26 +20,14 @@ func equalAsString(a interface{}, b interface{}) bool { } func Preload(scope *Scope) { - // Get Fields - var fields map[string]*Field - var isSlice bool - if scope.IndirectValue().Kind() == reflect.Slice { - isSlice = true - typ := scope.IndirectValue().Type().Elem() - if typ.Kind() == reflect.Ptr { - typ = typ.Elem() - } - elem := reflect.New(typ).Elem() - fields = scope.New(elem.Addr().Interface()).Fields() - } else { - fields = scope.Fields() - } + fields := scope.Fields() + isSlice := scope.IndirectValue().Kind() == reflect.Slice if scope.Search.Preload != nil { for key, conditions := range scope.Search.Preload { for _, field := range fields { if field.Name == key && field.Relationship != nil { - results := makeSlice(field.Field) + results := makeSlice(field.Struct.Type) relation := field.Relationship primaryName := scope.PrimaryKeyField().Name associationPrimaryKey := scope.New(results).PrimaryKeyField().Name @@ -53,10 +41,10 @@ func Preload(scope *Scope) { for i := 0; i < resultValues.Len(); i++ { result := resultValues.Index(i) if isSlice { - value := getFieldValue(result, relation.ForeignFieldName) + value := getRealValue(result, relation.ForeignFieldName) objects := scope.IndirectValue() for j := 0; j < objects.Len(); j++ { - if equalAsString(getFieldValue(objects.Index(j), primaryName), value) { + if equalAsString(getRealValue(objects.Index(j), primaryName), value) { reflect.Indirect(objects.Index(j)).FieldByName(field.Name).Set(result) break } @@ -72,11 +60,11 @@ func Preload(scope *Scope) { if isSlice { for i := 0; i < resultValues.Len(); i++ { result := resultValues.Index(i) - value := getFieldValue(result, relation.ForeignFieldName) + value := getRealValue(result, relation.ForeignFieldName) objects := scope.IndirectValue() for j := 0; j < objects.Len(); j++ { object := reflect.Indirect(objects.Index(j)) - if equalAsString(getFieldValue(object, primaryName), value) { + if equalAsString(getRealValue(object, primaryName), value) { f := object.FieldByName(field.Name) f.Set(reflect.Append(f, result)) break @@ -92,11 +80,11 @@ func Preload(scope *Scope) { for i := 0; i < resultValues.Len(); i++ { result := resultValues.Index(i) if isSlice { - value := getFieldValue(result, associationPrimaryKey) + value := getRealValue(result, associationPrimaryKey) objects := scope.IndirectValue() for j := 0; j < objects.Len(); j++ { object := reflect.Indirect(objects.Index(j)) - if equalAsString(getFieldValue(object, relation.ForeignFieldName), value) { + if equalAsString(getRealValue(object, relation.ForeignFieldName), value) { object.FieldByName(field.Name).Set(result) } } @@ -116,9 +104,8 @@ func Preload(scope *Scope) { } } -func makeSlice(value reflect.Value) interface{} { - typ := value.Type() - if value.Kind() == reflect.Slice { +func makeSlice(typ reflect.Type) interface{} { + if typ.Kind() == reflect.Slice { typ = typ.Elem() } sliceType := reflect.SliceOf(typ) @@ -132,8 +119,7 @@ func (scope *Scope) getColumnAsArray(column string) (primaryKeys []interface{}) switch values.Kind() { case reflect.Slice: for i := 0; i < values.Len(); i++ { - value := values.Index(i) - primaryKeys = append(primaryKeys, reflect.Indirect(value).FieldByName(column).Interface()) + primaryKeys = append(primaryKeys, reflect.Indirect(values.Index(i)).FieldByName(column).Interface()) } case reflect.Struct: return []interface{}{values.FieldByName(column).Interface()} diff --git a/scope.go b/scope.go index 034f9e03d..db26c677a 100644 --- a/scope.go +++ b/scope.go @@ -111,7 +111,8 @@ func (scope *Scope) PrimaryKey() string { // PrimaryKeyZero check the primary key is blank or not func (scope *Scope) PrimaryKeyZero() bool { - return isBlank(reflect.ValueOf(scope.PrimaryKeyValue())) + field := scope.PrimaryKeyField() + return field == nil || field.IsBlank } // PrimaryKeyValue get the primary key's value @@ -125,35 +126,23 @@ func (scope *Scope) PrimaryKeyValue() interface{} { // HasColumn to check if has column func (scope *Scope) HasColumn(column string) bool { for _, field := range scope.GetStructFields() { - if !field.IsIgnored { - if field.Name == column || field.DBName == column { - return true - } + if field.IsNormal && (field.Name == column || field.DBName == column) { + return true } } return false } -// FieldValueByName to get column's value and existence -func (scope *Scope) FieldValueByName(name string) (interface{}, error) { - return FieldValueByName(name, scope.Value) -} - // SetColumn to set the column's value func (scope *Scope) SetColumn(column interface{}, value interface{}) error { if field, ok := column.(*Field); ok { return field.Set(value) } else if dbName, ok := column.(string); ok { - if scope.Value == nil { - return errors.New("scope value must not be nil for string columns") - } - if field, ok := scope.Fields()[dbName]; ok { return field.Set(value) } dbName = ToSnake(dbName) - if field, ok := scope.Fields()[dbName]; ok { return field.Set(value) } @@ -204,45 +193,11 @@ func (scope *Scope) AddToVars(value interface{}) string { } // TableName get table name - func (scope *Scope) TableName() string { if scope.Search != nil && len(scope.Search.TableName) > 0 { return scope.Search.TableName } - - if scope.Value == nil { - scope.Err(errors.New("can't get table name")) - return "" - } - - data := scope.IndirectValue() - if data.Kind() == reflect.Slice { - elem := data.Type().Elem() - if elem.Kind() == reflect.Ptr { - elem = elem.Elem() - } - data = reflect.New(elem).Elem() - } - - if fm := data.MethodByName("TableName"); fm.IsValid() { - if v := fm.Call([]reflect.Value{}); len(v) > 0 { - if result, ok := v[0].Interface().(string); ok { - return result - } - } - } - - str := ToSnake(data.Type().Name()) - - if scope.db == nil || !scope.db.parent.singularTable { - for index, reg := range pluralMapKeys { - if reg.MatchString(str) { - return reg.ReplaceAllString(str, pluralMapValues[index]) - } - } - } - - return str + return scope.GetModelStruct().TableName } func (scope *Scope) QuotedTableName() string { @@ -284,8 +239,7 @@ func (scope *Scope) Exec() *Scope { defer scope.Trace(NowFunc()) if !scope.HasError() { - result, err := scope.DB().Exec(scope.Sql, scope.SqlVars...) - if scope.Err(err) == nil { + if result, err := scope.DB().Exec(scope.Sql, scope.SqlVars...); scope.Err(err) == nil { if count, err := result.RowsAffected(); err == nil { scope.db.RowsAffected = count } diff --git a/scope_private.go b/scope_private.go index 217df03f3..5c125b0a0 100644 --- a/scope_private.go +++ b/scope_private.go @@ -25,10 +25,8 @@ func (scope *Scope) buildWhereCondition(clause map[string]interface{}) (str stri } else if value != "" { str = fmt.Sprintf("(%v)", value) } - case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64: + case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64, sql.NullInt64: return scope.primaryCondition(scope.AddToVars(value)) - case sql.NullInt64: - return scope.primaryCondition(scope.AddToVars(value.Int64)) case []int, []int8, []int16, []int32, []int64, []uint, []uint8, []uint16, []uint32, []uint64, []string, []interface{}: str = fmt.Sprintf("(%v in (?))", scope.Quote(scope.PrimaryKey())) clause["args"] = []interface{}{value} @@ -71,12 +69,14 @@ func (scope *Scope) buildWhereCondition(clause map[string]interface{}) (str stri func (scope *Scope) buildNotCondition(clause map[string]interface{}) (str string) { var notEqualSql string + var primaryKey = scope.PrimaryKey() switch value := clause["query"].(type) { case string: + // is number if regexp.MustCompile("^\\s*\\d+\\s*$").MatchString(value) { id, _ := strconv.Atoi(value) - return fmt.Sprintf("(%v <> %v)", scope.Quote(scope.PrimaryKey()), id) + return fmt.Sprintf("(%v <> %v)", scope.Quote(primaryKey), id) } else if regexp.MustCompile("(?i) (=|<>|>|<|LIKE|IS) ").MatchString(value) { str = fmt.Sprintf(" NOT (%v) ", value) notEqualSql = fmt.Sprintf("NOT (%v)", value) @@ -84,15 +84,14 @@ func (scope *Scope) buildNotCondition(clause map[string]interface{}) (str string str = fmt.Sprintf("(%v NOT IN (?))", scope.Quote(value)) notEqualSql = fmt.Sprintf("(%v <> ?)", scope.Quote(value)) } - case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64: - return fmt.Sprintf("(%v <> %v)", scope.Quote(scope.PrimaryKey()), value) + case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64, sql.NullInt64: + return fmt.Sprintf("(%v <> %v)", scope.Quote(primaryKey), value) case []int, []int8, []int16, []int32, []int64, []uint, []uint8, []uint16, []uint32, []uint64, []string: if reflect.ValueOf(value).Len() > 0 { - str = fmt.Sprintf("(%v not in (?))", scope.Quote(scope.PrimaryKey())) + str = fmt.Sprintf("(%v NOT IN (?))", scope.Quote(primaryKey)) clause["args"] = []interface{}{value} - } else { - return "" } + return "" case map[string]interface{}: var sqls []string for key, value := range value { @@ -157,16 +156,10 @@ func (scope *Scope) buildSelectQuery(clause map[string]interface{}) (str string) return } -func (scope *Scope) where(where ...interface{}) { - if len(where) > 0 { - scope.Search = scope.Search.clone().where(where[0], where[1:]...) - } -} - func (scope *Scope) whereSql() (sql string) { var primaryConditions, andConditions, orConditions []string - if !scope.Search.Unscope && scope.HasColumn("DeletedAt") { + if !scope.Search.Unscope && scope.Fields()["deleted_at"] != nil { sql := fmt.Sprintf("(%v.deleted_at IS NULL OR %v.deleted_at <= '0001-01-02')", scope.QuotedTableName(), scope.QuotedTableName()) primaryConditions = append(primaryConditions, sql) } @@ -317,41 +310,19 @@ func (scope *Scope) callCallbacks(funcs []*func(s *Scope)) *Scope { } func (scope *Scope) updatedAttrsWithValues(values map[string]interface{}, ignoreProtectedAttrs bool) (results map[string]interface{}, hasUpdate bool) { - data := scope.IndirectValue() - if !data.CanAddr() { + if !scope.IndirectValue().CanAddr() { return values, true } + fields := scope.Fields() for key, value := range values { - if field, ok := scope.FieldByName(SnakeToUpperCamel(key)); ok && field.Field.IsValid() { - func() { - defer func() { - if err := recover(); err != nil { - hasUpdate = true - field.Set(value) - } - }() - - if field.Field.Interface() != value { - switch field.Field.Kind() { - case reflect.Int, reflect.Int32, reflect.Int64: - if s, ok := value.(string); ok { - i, err := strconv.Atoi(s) - if scope.Err(err) == nil { - value = i - } - } - - if field.Field.Int() != reflect.ValueOf(value).Int() { - hasUpdate = true - field.Set(value) - } - default: - hasUpdate = true - field.Set(value) - } + if field, ok := fields[ToSnake(key)]; ok && field.Field.IsValid() { + if !reflect.DeepEqual(field.Field, reflect.ValueOf(value)) { + if !equalAsString(field.Field.Interface(), value) { + hasUpdate = true + field.Set(value) } - }() + } } } return @@ -464,7 +435,7 @@ func (scope *Scope) createJoinTable(field *StructField) { if field.Relationship != nil && field.Relationship.JoinTable != "" { if !scope.Dialect().HasTable(scope, field.Relationship.JoinTable) { newScope := scope.db.NewScope("") - primaryKeySqlType := scope.Dialect().SqlTag(reflect.ValueOf(scope.PrimaryKeyValue()), 255) + primaryKeySqlType := scope.Dialect().SqlTag(scope.PrimaryKeyField().Field, 255) newScope.Raw(fmt.Sprintf("CREATE TABLE %v (%v)", field.Relationship.JoinTable, strings.Join([]string{ @@ -523,14 +494,7 @@ func (scope *Scope) addIndex(unique bool, indexName string, column ...string) { func (scope *Scope) addForeignKey(field string, dest string, onDelete string, onUpdate string) { var table = scope.TableName() var keyName = fmt.Sprintf("%s_%s_foreign", table, field) - var query = ` - ALTER TABLE %s - ADD CONSTRAINT %s - FOREIGN KEY (%s) - REFERENCES %s - ON DELETE %s - ON UPDATE %s; - ` + var query = `ALTER TABLE %s ADD CONSTRAINT %s FOREIGN KEY (%s) REFERENCES %s ON DELETE %s ON UPDATE %s;` scope.Raw(fmt.Sprintf(query, table, keyName, field, dest, onDelete, onUpdate)).Exec() } diff --git a/sqlite3.go b/sqlite3.go index 1fc47b446..c92e2cdb0 100644 --- a/sqlite3.go +++ b/sqlite3.go @@ -3,6 +3,7 @@ package gorm import ( "fmt" "reflect" + "time" ) type sqlite3 struct{} @@ -35,7 +36,7 @@ func (s *sqlite3) SqlTag(value reflect.Value, size int) string { } return "text" case reflect.Struct: - if value.Type() == timeType { + if _, ok := value.Interface().(time.Time); ok { return "datetime" } default: diff --git a/utils.go b/utils.go index 14862d96b..182b7376f 100644 --- a/utils.go +++ b/utils.go @@ -2,9 +2,6 @@ package gorm import ( "bytes" - "errors" - "fmt" - "reflect" "strings" "sync" ) @@ -26,23 +23,6 @@ func (s *safeMap) Get(key string) string { return s.m[key] } -func FieldValueByName(name string, value interface{}) (i interface{}, err error) { - data := reflect.Indirect(reflect.ValueOf(value)) - name = SnakeToUpperCamel(name) - - if data.Kind() == reflect.Struct { - if field := data.FieldByName(name); field.IsValid() { - i = field.Interface() - } else { - return nil, fmt.Errorf("struct has no field with name %s", name) - } - } else { - return nil, errors.New("value must be of kind struct") - } - - return -} - func newSafeMap() *safeMap { return &safeMap{l: new(sync.RWMutex), m: make(map[string]string)} } diff --git a/utils_private.go b/utils_private.go index 0b46cfdb2..1ebd49fe4 100644 --- a/utils_private.go +++ b/utils_private.go @@ -2,18 +2,16 @@ package gorm import ( "fmt" - "os" "reflect" "regexp" "runtime" - "strings" ) func fileWithLineNum() string { for i := 2; i < 15; i++ { _, file, line, ok := runtime.Caller(i) if ok && (!regexp.MustCompile(`jinzhu/gorm/.*.go`).MatchString(file) || regexp.MustCompile(`jinzhu/gorm/.*test.go`).MatchString(file)) { - return fmt.Sprintf("%v:%v", strings.TrimPrefix(file, os.Getenv("GOPATH")+"src/"), line) + return fmt.Sprintf("%v:%v", file, line) } } return ""