Skip to content

Commit

Permalink
Fix tests for postgres
Browse files Browse the repository at this point in the history
  • Loading branch information
jinzhu committed Aug 30, 2014
1 parent e9ecf9c commit 6271cf0
Show file tree
Hide file tree
Showing 6 changed files with 57 additions and 55 deletions.
2 changes: 1 addition & 1 deletion callback_create.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ func Create(scope *Scope) {
var sqls, columns []string

for _, field := range scope.Fields() {
if len(field.SqlTag) > 0 && !field.IsIgnored && (!field.IsPrimaryKey || !scope.PrimaryKeyZero()) {
if field.IsNormal && (!field.IsPrimaryKey || !scope.PrimaryKeyZero()) {
columns = append(columns, scope.Quote(field.DBName))
sqls = append(sqls, scope.AddToVars(field.Value))
}
Expand Down
8 changes: 3 additions & 5 deletions callback_update.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,16 +25,14 @@ func AssignUpdateAttributes(scope *Scope) {
}

func BeforeUpdate(scope *Scope) {
_, ok := scope.Get("gorm:update_column")
if !ok {
if _, ok := scope.Get("gorm:update_column"); !ok {
scope.CallMethod("BeforeSave")
scope.CallMethod("BeforeUpdate")
}
}

func UpdateTimeStampWhenUpdate(scope *Scope) {
_, ok := scope.Get("gorm:update_column")
if !ok {
if _, ok := scope.Get("gorm:update_column"); !ok {
scope.SetColumn("UpdatedAt", NowFunc())
}
}
Expand All @@ -50,7 +48,7 @@ func Update(scope *Scope) {
}
} else {
for _, field := range scope.Fields() {
if !field.IsPrimaryKey && len(field.SqlTag) > 0 && !field.IsIgnored {
if !field.IsPrimaryKey && field.IsNormal && !field.IsIgnored {
sqls = append(sqls, fmt.Sprintf("%v = %v", scope.Quote(field.DBName), scope.AddToVars(field.Value)))
}
}
Expand Down
2 changes: 1 addition & 1 deletion field.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,8 @@ type Field struct {
Field reflect.Value
Value interface{}
Tag reflect.StructTag
SqlTag string
Relationship *relationship
IsNormal bool
IsBlank bool
IsIgnored bool
IsPrimaryKey bool
Expand Down
53 changes: 46 additions & 7 deletions scope.go
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,20 @@ func (scope *Scope) PrimaryKey() string {
return scope.primaryKey
}

scope.primaryKey = ToSnake(GetPrimaryKey(scope.Value))
var indirectValue = scope.IndirectValue()

clone := scope
if indirectValue.Kind() == reflect.Slice {
clone = scope.New(reflect.New(indirectValue.Type().Elem()).Elem().Interface())
}

for _, field := range clone.Fields() {
if field.IsPrimaryKey {
scope.primaryKey = field.DBName
break
}
}

return scope.primaryKey
}

Expand Down Expand Up @@ -130,8 +143,12 @@ func (scope *Scope) SetColumn(column string, value interface{}) bool {
if scope.Value == nil {
return false
}

return setFieldValue(scope.IndirectValue().FieldByName(SnakeToUpperCamel(column)), value)
for _, field := range scope.Fields() {
if field.Name == column || field.DBName == column {
return setFieldValue(field.Field, value)
}
}
return false
}

// CallMethod invoke method with necessary argument
Expand Down Expand Up @@ -262,13 +279,19 @@ func (scope *Scope) fieldFromStruct(fieldStruct reflect.StructField) []*Field {

// Search for primary key tag identifier
settings := parseTagSetting(fieldStruct.Tag.Get("gorm"))

if scope.PrimaryKey() == field.DBName {
if _, ok := settings["PRIMARY_KEY"]; ok {
field.IsPrimaryKey = true
}

field.Tag = fieldStruct.Tag
field.SqlTag = scope.sqlTagForField(&field)

tagIdentifier := "sql"
if scope.db != nil {
tagIdentifier = scope.db.parent.tagIdentifier
}
if fieldStruct.Tag.Get(tagIdentifier) == "-" {
field.IsIgnored = true
}

if !field.IsIgnored {
// parse association
Expand Down Expand Up @@ -311,6 +334,8 @@ func (scope *Scope) fieldFromStruct(fieldStruct reflect.StructField) []*Field {
if many2many != "" {
field.Relationship.Kind = "many_to_many"
}
} else {
field.IsNormal = true
}
case reflect.Struct:
embedded := settings["EMBEDDED"]
Expand All @@ -321,7 +346,9 @@ func (scope *Scope) fieldFromStruct(fieldStruct reflect.StructField) []*Field {
fields = append(fields, field)
}
return fields
} else if !field.IsTime() && !field.IsScanner() {
} else if field.IsTime() || field.IsScanner() {
field.IsNormal = true
} else {
if foreignKey == "" && scope.HasColumn(field.Name+"Id") {
field.Relationship = &relationship{ForeignKey: field.Name + "Id", Kind: "belongs_to"}
} else if scope.HasColumn(foreignKey) {
Expand All @@ -335,6 +362,8 @@ func (scope *Scope) fieldFromStruct(fieldStruct reflect.StructField) []*Field {
}
}
}
default:
field.IsNormal = true
}
}
return []*Field{&field}
Expand All @@ -345,19 +374,29 @@ func (scope *Scope) Fields() map[string]*Field {
var fields = map[string]*Field{}
if scope.IndirectValue().IsValid() {
scopeTyp := scope.IndirectValue().Type()
var hasPrimaryKey = false
for i := 0; i < scopeTyp.NumField(); i++ {
fieldStruct := scopeTyp.Field(i)
if !ast.IsExported(fieldStruct.Name) {
continue
}
for _, field := range scope.fieldFromStruct(fieldStruct) {
if field.IsPrimaryKey {
hasPrimaryKey = true
}
if _, ok := fields[field.DBName]; ok {
panic(fmt.Sprintf("Duplicated column name for %v (%v)\n", scope.typeName(), fileWithLineNum()))
} else {
fields[field.DBName] = field
}
}
}

if !hasPrimaryKey {
if field, ok := fields["id"]; ok {
field.IsPrimaryKey = true
}
}
}
return fields
}
Expand Down
15 changes: 6 additions & 9 deletions scope_private.go
Original file line number Diff line number Diff line change
Expand Up @@ -305,11 +305,6 @@ func (scope *Scope) sqlTagForField(field *Field) (typ string) {
var size = 255

fieldTag := field.Tag.Get(scope.db.parent.tagIdentifier)
if fieldTag == "-" {
field.IsIgnored = true
return
}

var setting = parseTagSetting(fieldTag)

if value, ok := setting["SIZE"]; ok {
Expand Down Expand Up @@ -481,8 +476,9 @@ func (scope *Scope) createJoinTable(field *Field) {
func (scope *Scope) createTable() *Scope {
var sqls []string
for _, field := range scope.Fields() {
if !field.IsIgnored && len(field.SqlTag) > 0 {
sqls = append(sqls, scope.Quote(field.DBName)+" "+field.SqlTag)
if field.IsNormal {
sqlTag := scope.sqlTagForField(field)
sqls = append(sqls, scope.Quote(field.DBName)+" "+sqlTag)
}
scope.createJoinTable(field)
}
Expand Down Expand Up @@ -535,8 +531,9 @@ func (scope *Scope) autoMigrate() *Scope {
} else {
for _, field := range scope.Fields() {
if !scope.Dialect().HasColumn(scope, tableName, field.DBName) {
if len(field.SqlTag) > 0 && !field.IsIgnored {
scope.Raw(fmt.Sprintf("ALTER TABLE %v ADD %v %v;", quotedTableName, field.DBName, field.SqlTag)).Exec()
if field.IsNormal {
sqlTag := scope.sqlTagForField(field)
scope.Raw(fmt.Sprintf("ALTER TABLE %v ADD %v %v;", quotedTableName, field.DBName, sqlTag)).Exec()
}
}
scope.createJoinTable(field)
Expand Down
32 changes: 0 additions & 32 deletions utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ package gorm

import (
"bytes"
"go/ast"
"reflect"
"strings"
"sync"
Expand Down Expand Up @@ -91,37 +90,6 @@ func SnakeToUpperCamel(s string) string {
return u
}

func GetPrimaryKey(value interface{}) string {
var indirectValue = reflect.Indirect(reflect.ValueOf(value))

if indirectValue.Kind() == reflect.Slice {
indirectValue = reflect.New(indirectValue.Type().Elem()).Elem()
}

if indirectValue.IsValid() {
hasId := false
scopeTyp := indirectValue.Type()
for i := 0; i < scopeTyp.NumField(); i++ {
fieldStruct := scopeTyp.Field(i)
if !ast.IsExported(fieldStruct.Name) {
continue
}

settings := parseTagSetting(fieldStruct.Tag.Get("gorm"))
if _, ok := settings["PRIMARY_KEY"]; ok {
return fieldStruct.Name
} else if fieldStruct.Name == "Id" {
hasId = true
}
}
if hasId {
return "Id"
}
}

return ""
}

func parseTagSetting(str string) map[string]string {
tags := strings.Split(str, ";")
setting := map[string]string{}
Expand Down

0 comments on commit 6271cf0

Please sign in to comment.