Skip to content

Commit

Permalink
multpile foreign keys
Browse files Browse the repository at this point in the history
  • Loading branch information
jinzhu committed Jul 30, 2015
1 parent 82d726b commit a29230c
Show file tree
Hide file tree
Showing 6 changed files with 127 additions and 86 deletions.
8 changes: 5 additions & 3 deletions callback_create.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,9 +35,11 @@ func Create(scope *Scope) {
}
}
} else if relationship := field.Relationship; relationship != nil && relationship.Kind == "belongs_to" {
if relationField := fields[relationship.ForeignDBName]; !scope.changeableField(relationField) {
columns = append(columns, scope.Quote(relationField.DBName))
sqls = append(sqls, scope.AddToVars(relationField.Field.Interface()))
for _, dbName := range relationship.ForeignDBNames {
if relationField := fields[dbName]; !scope.changeableField(relationField) {
columns = append(columns, scope.Quote(relationField.DBName))
sqls = append(sqls, scope.AddToVars(relationField.Field.Interface()))
}
}
}
}
Expand Down
27 changes: 21 additions & 6 deletions callback_shared.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,13 @@ func SaveBeforeAssociations(scope *Scope) {
if relationship := field.Relationship; relationship != nil && relationship.Kind == "belongs_to" {
value := field.Field
scope.Err(scope.NewDB().Save(value.Addr().Interface()).Error)
if relationship.ForeignFieldName != "" {
scope.Err(scope.SetColumn(relationship.ForeignFieldName, scope.New(value.Addr().Interface()).PrimaryKeyValue()))
if len(relationship.ForeignFieldNames) != 0 {
for idx, fieldName := range relationship.ForeignFieldNames {
associationForeignName := relationship.AssociationForeignDBNames[idx]
if f, ok := scope.New(value.Addr().Interface()).FieldByName(associationForeignName); ok {
scope.Err(scope.SetColumn(fieldName, f.Field.Interface()))
}
}
}
}
}
Expand All @@ -44,8 +49,13 @@ func SaveAfterAssociations(scope *Scope) {
elem := value.Index(i).Addr().Interface()
newScope := newDB.NewScope(elem)

if relationship.JoinTableHandler == nil && relationship.ForeignFieldName != "" {
scope.Err(newScope.SetColumn(relationship.ForeignFieldName, scope.PrimaryKeyValue()))
if relationship.JoinTableHandler == nil && len(relationship.ForeignFieldNames) != 0 {
for idx, fieldName := range relationship.ForeignFieldNames {
associationForeignName := relationship.AssociationForeignDBNames[idx]
if f, ok := scope.New(value.Addr().Interface()).FieldByName(associationForeignName); ok {
scope.Err(scope.SetColumn(fieldName, f.Field.Interface()))
}
}
}

if relationship.PolymorphicType != "" {
Expand All @@ -61,8 +71,13 @@ func SaveAfterAssociations(scope *Scope) {
default:
elem := value.Addr().Interface()
newScope := scope.New(elem)
if relationship.ForeignFieldName != "" {
scope.Err(newScope.SetColumn(relationship.ForeignFieldName, scope.PrimaryKeyValue()))
if len(relationship.ForeignFieldNames) != 0 {
for idx, fieldName := range relationship.ForeignFieldNames {
associationForeignName := relationship.AssociationForeignDBNames[idx]
if f, ok := scope.New(value.Addr().Interface()).FieldByName(associationForeignName); ok {
scope.Err(scope.SetColumn(fieldName, f.Field.Interface()))
}
}
}

if relationship.PolymorphicType != "" {
Expand Down
7 changes: 4 additions & 3 deletions callback_update.go
Original file line number Diff line number Diff line change
Expand Up @@ -55,9 +55,10 @@ func Update(scope *Scope) {
sqls = append(sqls, fmt.Sprintf("%v = %v", scope.Quote(field.DBName), scope.AddToVars(field.Field.Interface())))
}
} else if relationship := field.Relationship; relationship != nil && relationship.Kind == "belongs_to" {
if relationField := fields[relationship.ForeignDBName]; !scope.changeableField(relationField) {
if !relationField.IsBlank {
sqls = append(sqls, fmt.Sprintf("%v = %v", scope.Quote(relationField.DBName), scope.AddToVars(relationField.Field.Interface())))
for _, dbName := range relationship.ForeignDBNames {
if relationField := fields[dbName]; !scope.changeableField(relationField) && !relationField.IsBlank {
sql := fmt.Sprintf("%v = %v", scope.Quote(relationField.DBName), scope.AddToVars(relationField.Field.Interface()))
sqls = append(sqls, sql)
}
}
}
Expand Down
35 changes: 6 additions & 29 deletions join_table_handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,41 +45,18 @@ func (s *JoinTableHandler) Setup(relationship *Relationship, tableName string, s
s.TableName = tableName

s.Source = JoinTableSource{ModelType: source}
sourceScope := &Scope{Value: reflect.New(source).Interface()}
sourcePrimaryFields := sourceScope.GetModelStruct().PrimaryFields
for _, primaryField := range sourcePrimaryFields {
if relationship.ForeignDBName == "" {
relationship.ForeignFieldName = source.Name() + primaryField.Name
relationship.ForeignDBName = ToDBName(relationship.ForeignFieldName)
}

var dbName string
if len(sourcePrimaryFields) == 1 || primaryField.DBName == "id" {
dbName = relationship.ForeignDBName
} else {
dbName = ToDBName(source.Name() + primaryField.Name)
}

for idx, dbName := range relationship.ForeignFieldNames {
s.Source.ForeignKeys = append(s.Source.ForeignKeys, JoinTableForeignKey{
DBName: dbName,
AssociationDBName: primaryField.DBName,
DBName: relationship.ForeignDBNames[idx],
AssociationDBName: dbName,
})
}

s.Destination = JoinTableSource{ModelType: destination}
destinationScope := &Scope{Value: reflect.New(destination).Interface()}
destinationPrimaryFields := destinationScope.GetModelStruct().PrimaryFields
for _, primaryField := range destinationPrimaryFields {
var dbName string
if len(sourcePrimaryFields) == 1 || primaryField.DBName == "id" {
dbName = relationship.AssociationForeignDBName
} else {
dbName = ToDBName(destinationScope.GetModelStruct().ModelType.Name() + primaryField.Name)
}

for idx, dbName := range relationship.AssociationForeignFieldNames {
s.Destination.ForeignKeys = append(s.Destination.ForeignKeys, JoinTableForeignKey{
DBName: dbName,
AssociationDBName: primaryField.DBName,
DBName: relationship.AssociationForeignDBNames[idx],
AssociationDBName: dbName,
})
}
}
Expand Down
2 changes: 1 addition & 1 deletion main.go
Original file line number Diff line number Diff line change
Expand Up @@ -445,7 +445,7 @@ func (s *DB) Association(column string) *Association {
err = errors.New("primary key can't be nil")
} else {
if field, ok := scope.FieldByName(column); ok {
if field.Relationship == nil || field.Relationship.ForeignFieldName == "" {
if field.Relationship == nil || len(field.Relationship.ForeignFieldNames) == 0 {
err = fmt.Errorf("invalid association %v for %v", column, scope.IndirectValue().Type())
} else {
return &Association{Scope: scope, Column: column, PrimaryKey: primaryField.Field.Interface(), Field: field}
Expand Down
134 changes: 90 additions & 44 deletions model_struct.go
Original file line number Diff line number Diff line change
Expand Up @@ -61,14 +61,14 @@ func (structField *StructField) clone() *StructField {
}

type Relationship struct {
Kind string
PolymorphicType string
PolymorphicDBName string
ForeignFieldName string
ForeignDBName string
AssociationForeignFieldName string
AssociationForeignDBName string
JoinTableHandler JoinTableHandlerInterface
Kind string
PolymorphicType string
PolymorphicDBName string
ForeignFieldNames []string
ForeignDBNames []string
AssociationForeignFieldNames []string
AssociationForeignDBNames []string
JoinTableHandler JoinTableHandlerInterface
}

var pluralMapKeys = []*regexp.Regexp{regexp.MustCompile("ch$"), regexp.MustCompile("ss$"), regexp.MustCompile("sh$"), regexp.MustCompile("day$"), regexp.MustCompile("y$"), regexp.MustCompile("x$"), regexp.MustCompile("([^s])s?$")}
Expand Down Expand Up @@ -190,12 +190,11 @@ func (scope *Scope) GetModelStruct() *ModelStruct {

var relationship = &Relationship{}

foreignKey := gormSettings["FOREIGNKEY"]
if polymorphic := gormSettings["POLYMORPHIC"]; polymorphic != "" {
if polymorphicField := getForeignField(polymorphic+"Id", toScope.GetStructFields()); polymorphicField != nil {
if polymorphicType := getForeignField(polymorphic+"Type", toScope.GetStructFields()); polymorphicType != nil {
relationship.ForeignFieldName = polymorphicField.Name
relationship.ForeignDBName = polymorphicField.DBName
relationship.ForeignFieldNames = []string{polymorphicField.Name}
relationship.ForeignDBNames = []string{polymorphicField.DBName}
relationship.PolymorphicType = polymorphicType.Name
relationship.PolymorphicDBName = polymorphicType.DBName
polymorphicType.IsForeignKey = true
Expand All @@ -204,6 +203,10 @@ func (scope *Scope) GetModelStruct() *ModelStruct {
}
}

var foreignKeys []string
if foreignKey, ok := gormSettings["FOREIGNKEY"]; ok {
foreignKeys := append(foreignKeys, gormSettings["FOREIGNKEY"])
}
switch indirectType.Kind() {
case reflect.Slice:
elemType := indirectType.Elem()
Expand All @@ -212,34 +215,63 @@ func (scope *Scope) GetModelStruct() *ModelStruct {
}

if elemType.Kind() == reflect.Struct {
if foreignKey == "" {
foreignKey = scopeType.Name() + "Id"
}

if many2many := gormSettings["MANY2MANY"]; many2many != "" {
relationship.Kind = "many_to_many"
associationForeignKey := gormSettings["ASSOCIATIONFOREIGNKEY"]
if associationForeignKey == "" {
associationForeignKey = elemType.Name() + "Id"

// foreign keys
if len(foreignKeys) == 0 {
for _, field := range scope.PrimaryFields() {
foreignKeys = append(foreignKeys, field.DBName)
}
}

relationship.ForeignFieldName = foreignKey
relationship.ForeignDBName = ToDBName(foreignKey)
relationship.AssociationForeignFieldName = associationForeignKey
relationship.AssociationForeignDBName = ToDBName(associationForeignKey)
for _, foreignKey := range foreignKeys {
if field, ok := scope.FieldByName(foreignKey); ok {
relationship.ForeignFieldNames = append(relationship.ForeignFieldNames, field.DBName)
joinTableDBName := ToDBName(scopeType.Name()) + "_" + field.DBName
relationship.ForeignDBNames = append(relationship.ForeignDBNames, joinTableDBName)
}
}

// association foreign keys
var associationForeignKeys []string
if foreignKey := gormSettings["ASSOCIATIONFOREIGNKEY"]; foreignKey != "" {
associationForeignKeys = []string{gormSettings["ASSOCIATIONFOREIGNKEY"]}
} else {
for _, field := range toScope.PrimaryFields() {
associationForeignKeys = append(associationForeignKeys, field.DBName)
}
}

for _, name := range associationForeignKeys {
if field, ok := toScope.FieldByName(name); ok {
relationship.AssociationForeignFieldNames = append(relationship.AssociationForeignFieldNames, name)
joinTableDBName := ToDBName(elemType.Name()) + "_" + field.DBName
relationship.AssociationForeignDBNames = append(relationship.AssociationForeignDBNames, joinTableDBName)
}
}

joinTableHandler := JoinTableHandler{}
joinTableHandler.Setup(relationship, many2many, scopeType, elemType)
relationship.JoinTableHandler = &joinTableHandler
field.Relationship = relationship
} else {
if len(foreignKeys) == 0 {
for _, field := range scope.PrimaryFields() {
foreignKeys = append(foreignKeys, scopeType.Name()+field.Name)
}
}

relationship.Kind = "has_many"
if foreignField := getForeignField(foreignKey, toScope.GetStructFields()); foreignField != nil {
relationship.ForeignFieldName = foreignField.Name
relationship.ForeignDBName = foreignField.DBName
foreignField.IsForeignKey = true
field.Relationship = relationship
} else if relationship.ForeignFieldName != "" {
for _, foreignKey := range foreignKeys {
if foreignField := getForeignField(foreignKey, toScope.GetStructFields()); foreignField != nil {
relationship.ForeignFieldNames = append(relationship.ForeignFieldNames, foreignField.Name)
relationship.ForeignDBNames = append(relationship.ForeignDBNames, foreignField.DBName)
foreignField.IsForeignKey = true
}
}

if len(relationship.ForeignFieldNames) != 0 {
field.Relationship = relationship
}
}
Expand All @@ -258,28 +290,42 @@ func (scope *Scope) GetModelStruct() *ModelStruct {
}
continue
} else {
belongsToForeignKey := foreignKey
if belongsToForeignKey == "" {
belongsToForeignKey = field.Name + "Id"
belongsToForeignKeys := foreignKeys
if len(belongsToForeignKeys) == 0 {
for _, field := range toScope.PrimaryFields() {
belongsToForeignKeys = append(belongsToForeignKeys, field.Name+field.Name)
}
}

for _, foreignKey := range belongsToForeignKeys {
if foreignField := getForeignField(foreignKey, fields); foreignField != nil {
relationship.ForeignFieldNames = append(relationship.ForeignFieldNames, foreignField.Name)
relationship.ForeignDBNames = append(relationship.ForeignDBNames, foreignField.DBName)
foreignField.IsForeignKey = true
}
}

if foreignField := getForeignField(belongsToForeignKey, fields); foreignField != nil {
if len(relationship.ForeignFieldNames) != 0 {
relationship.Kind = "belongs_to"
relationship.ForeignFieldName = foreignField.Name
relationship.ForeignDBName = foreignField.DBName
foreignField.IsForeignKey = true
field.Relationship = relationship
} else {
if foreignKey == "" {
foreignKey = modelStruct.ModelType.Name() + "Id"
hasOneForeignKeys := foreignKeys
if len(hasOneForeignKeys) == 0 {
for _, field := range toScope.PrimaryFields() {
hasOneForeignKeys = append(hasOneForeignKeys, modelStruct.ModelType.Name()+field.Name)
}
}
relationship.Kind = "has_one"
if foreignField := getForeignField(foreignKey, toScope.GetStructFields()); foreignField != nil {
relationship.ForeignFieldName = foreignField.Name
relationship.ForeignDBName = foreignField.DBName
foreignField.IsForeignKey = true
field.Relationship = relationship
} else if relationship.ForeignFieldName != "" {

for _, foreignKey := range hasOneForeignKeys {
if foreignField := getForeignField(foreignKey, toScope.GetStructFields()); foreignField != nil {
relationship.ForeignFieldNames = append(relationship.ForeignFieldNames, foreignField.Name)
relationship.ForeignDBNames = append(relationship.ForeignDBNames, foreignField.DBName)
foreignField.IsForeignKey = true
}
}

if len(relationship.ForeignFieldNames) != 0 {
relationship.Kind = "has_one"
field.Relationship = relationship
}
}
Expand Down

0 comments on commit a29230c

Please sign in to comment.