Skip to content

Commit

Permalink
Test ManyToMany relations with multi primary keys
Browse files Browse the repository at this point in the history
  • Loading branch information
jinzhu committed Apr 10, 2015
1 parent 67266eb commit 1eb1ed0
Show file tree
Hide file tree
Showing 4 changed files with 76 additions and 14 deletions.
1 change: 1 addition & 0 deletions association.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ func (association *Association) Append(values ...interface{}) *Association {
association.setErr(errors.New("invalid association type"))
}
}
scope.Search.Select(association.Column)
scope.callCallbacks(scope.db.parent.callback.updates)
return association.setErr(scope.db.Error)
}
Expand Down
16 changes: 9 additions & 7 deletions callback_update.go
Original file line number Diff line number Diff line change
Expand Up @@ -64,13 +64,15 @@ func Update(scope *Scope) {
}
}

scope.Raw(fmt.Sprintf(
"UPDATE %v SET %v %v",
scope.QuotedTableName(),
strings.Join(sqls, ", "),
scope.CombinedConditionSql(),
))
scope.Exec()
if len(sqls) > 0 {
scope.Raw(fmt.Sprintf(
"UPDATE %v SET %v %v",
scope.QuotedTableName(),
strings.Join(sqls, ", "),
scope.CombinedConditionSql(),
))
scope.Exec()
}
}
}

Expand Down
27 changes: 20 additions & 7 deletions join_table_handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,26 +36,39 @@ func (s *JoinTableHandler) Setup(relationship *Relationship, tableName string, s

s.Source = JoinTableSource{ModelType: source}
sourceScope := &Scope{Value: reflect.New(source).Interface()}
for _, primaryField := range sourceScope.GetModelStruct().PrimaryFields {
sourcePrimaryFields := sourceScope.GetModelStruct().PrimaryFields
for _, primaryField := range sourcePrimaryFields {
if relationship.ForeignDBName == "" {
relationship.ForeignFieldName = source.Name() + primaryField.Name
relationship.ForeignDBName = ToDBName(relationship.ForeignFieldName)
}

var dbName string
if len(sourcePrimaryFields) == 1 || primaryField.DBName == "id" {
dbName = relationship.ForeignDBName
} else {
dbName = ToDBName(source.Name() + primaryField.Name)
}

s.Source.ForeignKeys = append(s.Source.ForeignKeys, JoinTableForeignKey{
DBName: relationship.ForeignDBName,
DBName: dbName,
AssociationDBName: primaryField.DBName,
})
}

s.Destination = JoinTableSource{ModelType: destination}
destinationScope := &Scope{Value: reflect.New(destination).Interface()}
for _, primaryField := range destinationScope.GetModelStruct().PrimaryFields {
if relationship.AssociationForeignDBName == "" {
relationship.AssociationForeignFieldName = destination.Name() + primaryField.Name
relationship.AssociationForeignDBName = ToDBName(relationship.AssociationForeignFieldName)
destinationPrimaryFields := destinationScope.GetModelStruct().PrimaryFields
for _, primaryField := range destinationPrimaryFields {
var dbName string
if len(sourcePrimaryFields) == 1 || primaryField.DBName == "id" {
dbName = relationship.AssociationForeignDBName
} else {
dbName = ToDBName(destinationScope.GetModelStruct().ModelType.Name() + primaryField.Name)
}

s.Destination.ForeignKeys = append(s.Destination.ForeignKeys, JoinTableForeignKey{
DBName: relationship.AssociationForeignDBName,
DBName: dbName,
AssociationDBName: primaryField.DBName,
})
}
Expand Down
46 changes: 46 additions & 0 deletions multi_primary_keys_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
package gorm_test

import (
"fmt"
"os"
"testing"
)

type Blog struct {
ID uint `gorm:"primary_key"`
Locale string `gorm:"primary_key"`
Subject string
Body string
Tags []Tag `gorm:"many2many:blog_tags;"`
}

type Tag struct {
ID uint `gorm:"primary_key"`
Locale string `gorm:"primary_key"`
Value string
}

func TestManyToManyWithMultiPrimaryKeys(t *testing.T) {
if dialect := os.Getenv("GORM_DIALECT"); dialect != "sqlite" {
DB.Exec(fmt.Sprintf("drop table blog_tags;"))
DB.AutoMigrate(&Blog{}, &Tag{})
blog := Blog{
Locale: "ZH",
Subject: "subject",
Body: "body",
Tags: []Tag{
{Locale: "ZH", Value: "tag1"},
{Locale: "ZH", Value: "tag2"},
},
}

DB.Save(&blog)
DB.Model(&blog).Association("Tags").Append([]Tag{{Locale: "ZH", Value: "tag3"}})

var tags []Tag
DB.Model(&blog).Related(&tags, "Tags")
if len(tags) != 3 {
t.Errorf("should found 3 tags with blog")
}
}
}

0 comments on commit 1eb1ed0

Please sign in to comment.