Skip to content

Commit

Permalink
Create with Select
Browse files Browse the repository at this point in the history
  • Loading branch information
jinzhu committed Mar 12, 2015
1 parent da7830e commit ad251b9
Show file tree
Hide file tree
Showing 8 changed files with 116 additions and 50 deletions.
20 changes: 15 additions & 5 deletions callback_create.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,11 +24,21 @@ func Create(scope *Scope) {
if !scope.HasError() {
// set create sql
var sqls, columns []string
for _, field := range scope.Fields() {
if (field.IsNormal && !field.IsPrimaryKey) || (field.IsPrimaryKey && !field.IsBlank) {
if !field.IsBlank || !field.HasDefaultValue {
columns = append(columns, scope.Quote(field.DBName))
sqls = append(sqls, scope.AddToVars(field.Field.Interface()))
fields := scope.Fields()
for _, field := range fields {
if scope.ValidField(field) {
if field.IsNormal {
if !field.IsPrimaryKey || (field.IsPrimaryKey && !field.IsBlank) {
if !field.IsBlank || !field.HasDefaultValue {
columns = append(columns, scope.Quote(field.DBName))
sqls = append(sqls, scope.AddToVars(field.Field.Interface()))
}
}
} else if relationship := field.Relationship; relationship != nil && relationship.Kind == "belongs_to" {
if relationField := fields[relationship.ForeignDBName]; !scope.ValidField(relationField) {
columns = append(columns, scope.Quote(relationField.DBName))
sqls = append(sqls, scope.AddToVars(relationField.Field.Interface()))
}
}
}
}
Expand Down
4 changes: 2 additions & 2 deletions callback_shared.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ func CommitOrRollbackTransaction(scope *Scope) {

func SaveBeforeAssociations(scope *Scope) {
for _, field := range scope.Fields() {
if !field.IsBlank && !field.IsIgnored {
if scope.ValidField(field) && !field.IsBlank && !field.IsIgnored {
if relationship := field.Relationship; relationship != nil && relationship.Kind == "belongs_to" {
value := field.Field
scope.Err(scope.NewDB().Save(value.Addr().Interface()).Error)
Expand All @@ -26,7 +26,7 @@ func SaveBeforeAssociations(scope *Scope) {

func SaveAfterAssociations(scope *Scope) {
for _, field := range scope.Fields() {
if !field.IsBlank && !field.IsIgnored {
if scope.ValidField(field) && !field.IsBlank && !field.IsIgnored {
if relationship := field.Relationship; relationship != nil &&
(relationship.Kind == "has_one" || relationship.Kind == "has_many" || relationship.Kind == "many_to_many") {
value := field.Field
Expand Down
20 changes: 20 additions & 0 deletions create_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package gorm_test

import (
"fmt"
"reflect"
"testing"
"time"
Expand Down Expand Up @@ -121,3 +122,22 @@ func TestAnonymousField(t *testing.T) {
t.Errorf("Should be able to get anonymous field")
}
}

func TestSelectCreate(t *testing.T) {
user := getPreparedUser("user1", "select_create")
DB.Select("Name", "BillingAddress", "CreditCard", "Company", "Emails").Create(&user)

var user2 User
DB.Preload("BillingAddress").Preload("ShippingAddress").
Preload("CreditCard").Preload("Emails").Preload("Company").First(&user2, user.Id)

if user2.Name != user.Name || user2.Age == user.Age {
t.Errorf("Should only create users with name column")
}

fmt.Println(user2.CreditCard.ID)
if user2.BillingAddressID.Int64 == 0 || user2.ShippingAddressId != 0 ||
user2.CreditCard.ID == 0 || len(user2.Emails) == 0 {
t.Errorf("Should only create users with name column")
}
}
14 changes: 10 additions & 4 deletions main_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -89,8 +89,12 @@ func TestExceptionsWithInvalidSql(t *testing.T) {
}

func TestSetTable(t *testing.T) {
if DB.Table("users").Pluck("age", &[]int{}).Error != nil {
t.Errorf("No errors should happen if set table for pluck")
DB.Create(getPreparedUser("pluck_user1", "pluck_user"))
DB.Create(getPreparedUser("pluck_user2", "pluck_user"))
DB.Create(getPreparedUser("pluck_user3", "pluck_user"))

if err := DB.Table("users").Where("role = ?", "pluck_user").Pluck("age", &[]int{}).Error; err != nil {
t.Errorf("No errors should happen if set table for pluck", err.Error())
}

var users []User
Expand All @@ -115,9 +119,11 @@ func TestSetTable(t *testing.T) {
t.Errorf("Query from specified table")
}

DB.Save(getPreparedUser("normal_user", "reset_table"))
DB.Table("deleted_users").Save(getPreparedUser("deleted_user", "reset_table"))
var user1, user2, user3 User
DB.First(&user1).Table("deleted_users").First(&user2).Table("").First(&user3)
if (user1.Name == user2.Name) || (user1.Name != user3.Name) {
DB.Where("role = ?", "reset_table").First(&user1).Table("deleted_users").First(&user2).Table("").First(&user3)
if (user1.Name != "normal_user") || (user2.Name != "deleted_user") || (user3.Name != "normal_user") {
t.Errorf("unset specified table with blank string")
}
}
Expand Down
34 changes: 7 additions & 27 deletions preload_test.go
Original file line number Diff line number Diff line change
@@ -1,29 +1,9 @@
package gorm_test

import (
"fmt"
"testing"
)

func getPreloadUser(name string) User {
var company Company
DB.Where(Company{Name: "preload"}).FirstOrCreate(&company)

return User{
Name: name,
Role: Role{"Preload"},
BillingAddress: Address{Address1: fmt.Sprintf("Billing Address %v", name)},
ShippingAddress: Address{Address1: fmt.Sprintf("Shipping Address %v", name)},
CreditCard: CreditCard{Number: fmt.Sprintf("123456%v", name)},
Emails: []Email{
{Email: fmt.Sprintf("user_%v@example1.com", name)}, {Email: fmt.Sprintf("user_%v@example2.com", name)},
},
Company: company,
Languages: []Language{
{Name: fmt.Sprintf("lang_1_%v", name)},
{Name: fmt.Sprintf("lang_2_%v", name)},
},
}
import "testing"

func getPreloadUser(name string) *User {
return getPreparedUser(name, "Preload")
}

func checkUserHasPreloadData(user User, t *testing.T) {
Expand Down Expand Up @@ -64,7 +44,7 @@ func checkUserHasPreloadData(user User, t *testing.T) {

func TestPreload(t *testing.T) {
user1 := getPreloadUser("user1")
DB.Save(&user1)
DB.Save(user1)

preloadDB := DB.Where("role = ?", "Preload").Preload("BillingAddress").Preload("ShippingAddress").
Preload("CreditCard").Preload("Emails").Preload("Company")
Expand All @@ -73,10 +53,10 @@ func TestPreload(t *testing.T) {
checkUserHasPreloadData(user, t)

user2 := getPreloadUser("user2")
DB.Save(&user2)
DB.Save(user2)

user3 := getPreloadUser("user3")
DB.Save(&user3)
DB.Save(user3)

var users []User
preloadDB.Find(&users)
Expand Down
39 changes: 39 additions & 0 deletions scope.go
Original file line number Diff line number Diff line change
Expand Up @@ -333,3 +333,42 @@ func (scope *Scope) CommitOrRollback() *Scope {
}
return scope
}

func (scope *Scope) SelectAttrs() (attrs []string) {
for _, value := range scope.Search.selects {
if str, ok := value.(string); ok {
attrs = append(attrs, str)
} else if strs, ok := value.([]interface{}); ok {
for _, str := range strs {
attrs = append(attrs, fmt.Sprintf("%v", str))
}
}
}
return attrs
}

func (scope *Scope) OmitAttrs() []string {
return scope.Search.omits
}

func (scope *Scope) ValidField(field *Field) bool {
selectAttrs := scope.SelectAttrs()
omitAttrs := scope.OmitAttrs()

if len(selectAttrs) > 0 {
for _, attr := range selectAttrs {
if field.Name == attr || field.DBName == attr {
return true
}
}
return false
}

for _, attr := range omitAttrs {
if field.Name == attr || field.DBName == attr {
return false
}
}

return !field.IsIgnored
}
12 changes: 0 additions & 12 deletions search.go
Original file line number Diff line number Diff line change
Expand Up @@ -89,18 +89,6 @@ func (s *search) Omit(columns ...string) *search {
return s
}

func (s *search) SelectAttrs() (attrs []string) {
for key, value := range s.selects {
attrs = append(attrs, key)
attrs = append(attrs, value.([]string)...)
}
return attrs
}

func (s *search) OmitAttrs() []string {
return s.omits
}

func (s *search) Limit(value interface{}) *search {
s.limit = s.getInterfaceAsSql(value)
return s
Expand Down
23 changes: 23 additions & 0 deletions structs_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"database/sql"
"database/sql/driver"
"errors"
"fmt"

"reflect"
"time"
Expand Down Expand Up @@ -194,3 +195,25 @@ func (nt NullTime) Value() (driver.Value, error) {
}
return nt.Time, nil
}

func getPreparedUser(name string, role string) *User {
var company Company
DB.Where(Company{Name: role}).FirstOrCreate(&company)

return &User{
Name: name,
Age: 20,
Role: Role{role},
BillingAddress: Address{Address1: fmt.Sprintf("Billing Address %v", name)},
ShippingAddress: Address{Address1: fmt.Sprintf("Shipping Address %v", name)},
CreditCard: CreditCard{Number: fmt.Sprintf("123456%v", name)},
Emails: []Email{
{Email: fmt.Sprintf("user_%v@example1.com", name)}, {Email: fmt.Sprintf("user_%v@example2.com", name)},
},
Company: company,
Languages: []Language{
{Name: fmt.Sprintf("lang_1_%v", name)},
{Name: fmt.Sprintf("lang_2_%v", name)},
},
}
}

0 comments on commit ad251b9

Please sign in to comment.