Skip to content

Commit

Permalink
First try for the Preload feature
Browse files Browse the repository at this point in the history
  • Loading branch information
jinzhu committed Feb 11, 2015
1 parent 8aef600 commit 3b784c3
Show file tree
Hide file tree
Showing 4 changed files with 53 additions and 9 deletions.
3 changes: 2 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -1092,8 +1092,9 @@ db.Where("email = ?", "x@example.org").Attrs(User{RegisteredIp: "111.111.111.111
db.Mode(&User{}).Do("EditForm").Get("edit_form_html")
DefaultTimeZone, R/W Splitting, Validation
* Github Pages
* Includes
* AlertColumn, DropColumn
* db.Preload("Addresses.Map", "active = ?", true).Preload("Profile").Find(&users)
* db.Find(&users).Related(&users)

# Author

Expand Down
39 changes: 38 additions & 1 deletion callback_query.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@ package gorm
import (
"fmt"
"reflect"

"github.com/jinzhu/gorm"
)

func Query(scope *Scope) {
Expand All @@ -13,6 +15,7 @@ func Query(scope *Scope) {
isPtr bool
anyRecordFound bool
destType reflect.Type
primaryKeys []interface{}
)

var dest = scope.IndirectValue()
Expand Down Expand Up @@ -47,8 +50,9 @@ func Query(scope *Scope) {
return
}

columns, _ := rows.Columns()
preloadMap := map[string]map[string]*gorm.Field{}

columns, _ := rows.Columns()
defer rows.Close()
for rows.Next() {
scope.db.RowsAffected += 1
Expand All @@ -62,13 +66,18 @@ func Query(scope *Scope) {
var values = make([]interface{}, len(columns))

fields := scope.New(elem.Addr().Interface()).Fields()
var primaryKey interface{}
for index, column := range columns {
if field, ok := fields[column]; ok {
if field.Field.Kind() == reflect.Ptr {
values[index] = field.Field.Addr().Interface()
} else {
values[index] = reflect.New(reflect.PtrTo(field.Field.Type())).Interface()
}
if field.IsPrimaryKey {
primaryKey = values[index]
primaryKeys = append(primaryKeys, primaryKey)
}
} else {
var value interface{}
values[index] = &value
Expand All @@ -95,6 +104,34 @@ func Query(scope *Scope) {
dest.Set(reflect.Append(dest, elem))
}
}

if scope.Search.Preload != nil {
for key := range scope.Search.Preload {
if field := fields[key]; field != nil {
if preloadMap[key] == nil {
preloadMap[key] = map[string]reflect.Value{}
}
preloadMap[key][fmt.Sprintf("%v", primaryKey)] = field
}
}
}
}

for _, value := range preloadMap {
var typ reflect.Type
var relation *Relation
for _, v := range value {
typ = v.Field.Type()
relation = v.Relationship
break
}
sliceType := reflect.SliceOf(typ)
slice := reflect.MakeSlice(sliceType, 0, 0)
slicePtr := reflect.New(sliceType)
slicePtr.Elem().Set(slice)
if relation == "has_many" {
scope.NewDB().Find(slicePtr.Interface(), primaryKeys)
}
}

if !anyRecordFound && !isSlice {
Expand Down
8 changes: 4 additions & 4 deletions main.go
Original file line number Diff line number Diff line change
Expand Up @@ -156,10 +156,6 @@ func (s *DB) Joins(query string) *DB {
return s.clone().search.joins(query).db
}

func (s *DB) Includes(value interface{}) *DB {
return s.clone().search.includes(value).db
}

func (s *DB) Scopes(funcs ...func(*DB) *DB) *DB {
c := s
for _, f := range funcs {
Expand Down Expand Up @@ -432,6 +428,10 @@ func (s *DB) Association(column string) *Association {
return &Association{Scope: scope, Column: column, Error: s.Error, PrimaryKey: primaryKey, PrimaryType: primaryType, Field: field}
}

func (s *DB) Preload(column string, conditions ...interface{}) *DB {
return s.clone().search.preload(column, conditions...).db
}

// Set set value by name
func (s *DB) Set(name string, value interface{}) *DB {
return s.clone().InstantSet(name, value)
Expand Down
12 changes: 9 additions & 3 deletions search.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ type search struct {
Orders []string
Joins string
Selects []map[string]interface{}
Preload map[string][]interface{}
Offset string
Limit string
Group string
Expand All @@ -23,6 +24,7 @@ type search struct {

func (s *search) clone() *search {
return &search{
Preload: s.Preload,
WhereConditions: s.WhereConditions,
OrConditions: s.OrConditions,
NotConditions: s.NotConditions,
Expand Down Expand Up @@ -100,12 +102,16 @@ func (s *search) having(query string, values ...interface{}) *search {
return s
}

func (s *search) includes(value interface{}) *search {
func (s *search) joins(query string) *search {
s.Joins = query
return s
}

func (s *search) joins(query string) *search {
s.Joins = query
func (s *search) preload(column string, values ...interface{}) *search {
if s.Preload == nil {
s.Preload = map[string][]interface{}{}
}
s.Preload[column] = values
return s
}

Expand Down

0 comments on commit 3b784c3

Please sign in to comment.