diff --git a/association.go b/association.go index 89bb1bec8..37e10516d 100644 --- a/association.go +++ b/association.go @@ -40,6 +40,7 @@ func (association *Association) Append(values ...interface{}) *Association { association.setErr(errors.New("invalid association type")) } } + scope.Search.Select(association.Column) scope.callCallbacks(scope.db.parent.callback.updates) return association.setErr(scope.db.Error) } diff --git a/callback_update.go b/callback_update.go index 1167871cb..c3f7b4b62 100644 --- a/callback_update.go +++ b/callback_update.go @@ -64,13 +64,15 @@ func Update(scope *Scope) { } } - scope.Raw(fmt.Sprintf( - "UPDATE %v SET %v %v", - scope.QuotedTableName(), - strings.Join(sqls, ", "), - scope.CombinedConditionSql(), - )) - scope.Exec() + if len(sqls) > 0 { + scope.Raw(fmt.Sprintf( + "UPDATE %v SET %v %v", + scope.QuotedTableName(), + strings.Join(sqls, ", "), + scope.CombinedConditionSql(), + )) + scope.Exec() + } } } diff --git a/join_table_handler.go b/join_table_handler.go index 9f705564b..e589d6f59 100644 --- a/join_table_handler.go +++ b/join_table_handler.go @@ -36,26 +36,39 @@ func (s *JoinTableHandler) Setup(relationship *Relationship, tableName string, s s.Source = JoinTableSource{ModelType: source} sourceScope := &Scope{Value: reflect.New(source).Interface()} - for _, primaryField := range sourceScope.GetModelStruct().PrimaryFields { + sourcePrimaryFields := sourceScope.GetModelStruct().PrimaryFields + for _, primaryField := range sourcePrimaryFields { if relationship.ForeignDBName == "" { relationship.ForeignFieldName = source.Name() + primaryField.Name relationship.ForeignDBName = ToDBName(relationship.ForeignFieldName) } + + var dbName string + if len(sourcePrimaryFields) == 1 || primaryField.DBName == "id" { + dbName = relationship.ForeignDBName + } else { + dbName = ToDBName(source.Name() + primaryField.Name) + } + s.Source.ForeignKeys = append(s.Source.ForeignKeys, JoinTableForeignKey{ - DBName: relationship.ForeignDBName, + DBName: dbName, AssociationDBName: primaryField.DBName, }) } s.Destination = JoinTableSource{ModelType: destination} destinationScope := &Scope{Value: reflect.New(destination).Interface()} - for _, primaryField := range destinationScope.GetModelStruct().PrimaryFields { - if relationship.AssociationForeignDBName == "" { - relationship.AssociationForeignFieldName = destination.Name() + primaryField.Name - relationship.AssociationForeignDBName = ToDBName(relationship.AssociationForeignFieldName) + destinationPrimaryFields := destinationScope.GetModelStruct().PrimaryFields + for _, primaryField := range destinationPrimaryFields { + var dbName string + if len(sourcePrimaryFields) == 1 || primaryField.DBName == "id" { + dbName = relationship.AssociationForeignDBName + } else { + dbName = ToDBName(destinationScope.GetModelStruct().ModelType.Name() + primaryField.Name) } + s.Destination.ForeignKeys = append(s.Destination.ForeignKeys, JoinTableForeignKey{ - DBName: relationship.AssociationForeignDBName, + DBName: dbName, AssociationDBName: primaryField.DBName, }) } diff --git a/multi_primary_keys_test.go b/multi_primary_keys_test.go new file mode 100644 index 000000000..4aa3517eb --- /dev/null +++ b/multi_primary_keys_test.go @@ -0,0 +1,46 @@ +package gorm_test + +import ( + "fmt" + "os" + "testing" +) + +type Blog struct { + ID uint `gorm:"primary_key"` + Locale string `gorm:"primary_key"` + Subject string + Body string + Tags []Tag `gorm:"many2many:blog_tags;"` +} + +type Tag struct { + ID uint `gorm:"primary_key"` + Locale string `gorm:"primary_key"` + Value string +} + +func TestManyToManyWithMultiPrimaryKeys(t *testing.T) { + if dialect := os.Getenv("GORM_DIALECT"); dialect != "sqlite" { + DB.Exec(fmt.Sprintf("drop table blog_tags;")) + DB.AutoMigrate(&Blog{}, &Tag{}) + blog := Blog{ + Locale: "ZH", + Subject: "subject", + Body: "body", + Tags: []Tag{ + {Locale: "ZH", Value: "tag1"}, + {Locale: "ZH", Value: "tag2"}, + }, + } + + DB.Save(&blog) + DB.Model(&blog).Association("Tags").Append([]Tag{{Locale: "ZH", Value: "tag3"}}) + + var tags []Tag + DB.Model(&blog).Related(&tags, "Tags") + if len(tags) != 3 { + t.Errorf("should found 3 tags with blog") + } + } +}