Skip to content

Commit

Permalink
support nested preloading
Browse files Browse the repository at this point in the history
  • Loading branch information
bom-d-van committed Apr 21, 2015
1 parent 055bf79 commit 6d58dc9
Show file tree
Hide file tree
Showing 3 changed files with 581 additions and 78 deletions.
219 changes: 147 additions & 72 deletions preload.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"errors"
"fmt"
"reflect"
"strings"
)

func getRealValue(value reflect.Value, field string) interface{} {
Expand All @@ -20,90 +21,139 @@ func equalAsString(a interface{}, b interface{}) bool {
}

func Preload(scope *Scope) {
preloadMap := map[string]bool{}
if scope.Search.preload != nil {
fields := scope.Fields()
isSlice := scope.IndirectValue().Kind() == reflect.Slice

for key, conditions := range scope.Search.preload {
for _, field := range fields {
if field.Name == key && field.Relationship != nil {
results := makeSlice(field.Struct.Type)
relation := field.Relationship
primaryName := scope.PrimaryField().Name
associationPrimaryKey := scope.New(results).PrimaryField().Name
for _, preload := range scope.Search.preload {
schema, conditions := preload.schema, preload.conditions
keys := strings.Split(schema, ".")
currentScope := scope
currentFields := fields
currentIsSlice := isSlice
originalConditions := conditions
conditions = []interface{}{}
for i, key := range keys {
// log.Printf("--> %+v\n", key)
if !preloadMap[strings.Join(keys[:i+1], ".")] {
if i == len(keys)-1 {
// log.Printf("--> %+v\n", originalConditions)
conditions = originalConditions
}

var found bool
for _, field := range currentFields {
if field.Name == key && field.Relationship != nil {
found = true
// log.Printf("--> %+v\n", field.Name)
results := makeSlice(field.Struct.Type)
relation := field.Relationship
primaryName := currentScope.PrimaryField().Name
associationPrimaryKey := currentScope.New(results).PrimaryField().Name

switch relation.Kind {
case "has_one":
if primaryKeys := scope.getColumnAsArray(primaryName); len(primaryKeys) > 0 {
condition := fmt.Sprintf("%v IN (?)", scope.Quote(relation.ForeignDBName))
scope.NewDB().Where(condition, primaryKeys).Find(results, conditions...)
switch relation.Kind {
case "has_one":
if primaryKeys := currentScope.getColumnAsArray(primaryName); len(primaryKeys) > 0 {
condition := fmt.Sprintf("%v IN (?)", currentScope.Quote(relation.ForeignDBName))
currentScope.NewDB().Where(condition, primaryKeys).Find(results, conditions...)

resultValues := reflect.Indirect(reflect.ValueOf(results))
for i := 0; i < resultValues.Len(); i++ {
result := resultValues.Index(i)
if isSlice {
value := getRealValue(result, relation.ForeignFieldName)
objects := scope.IndirectValue()
for j := 0; j < objects.Len(); j++ {
if equalAsString(getRealValue(objects.Index(j), primaryName), value) {
reflect.Indirect(objects.Index(j)).FieldByName(field.Name).Set(result)
break
resultValues := reflect.Indirect(reflect.ValueOf(results))
for i := 0; i < resultValues.Len(); i++ {
result := resultValues.Index(i)
if currentIsSlice {
value := getRealValue(result, relation.ForeignFieldName)
objects := currentScope.IndirectValue()
for j := 0; j < objects.Len(); j++ {
if equalAsString(getRealValue(objects.Index(j), primaryName), value) {
reflect.Indirect(objects.Index(j)).FieldByName(field.Name).Set(result)
break
}
}
} else {
// log.Printf("--> %+v\n", result.Interface())
err := currentScope.SetColumn(field, result)
if err != nil {
scope.Err(err)
return
}
// printutils.PrettyPrint(currentScope.Value)
}
}
} else {
scope.SetColumn(field, result)
// printutils.PrettyPrint(currentScope.Value)
}
}
}
case "has_many":
if primaryKeys := scope.getColumnAsArray(primaryName); len(primaryKeys) > 0 {
condition := fmt.Sprintf("%v IN (?)", scope.Quote(relation.ForeignDBName))
scope.NewDB().Where(condition, primaryKeys).Find(results, conditions...)
resultValues := reflect.Indirect(reflect.ValueOf(results))
if isSlice {
for i := 0; i < resultValues.Len(); i++ {
result := resultValues.Index(i)
value := getRealValue(result, relation.ForeignFieldName)
objects := scope.IndirectValue()
for j := 0; j < objects.Len(); j++ {
object := reflect.Indirect(objects.Index(j))
if equalAsString(getRealValue(object, primaryName), value) {
f := object.FieldByName(field.Name)
f.Set(reflect.Append(f, result))
break
case "has_many":
// log.Printf("--> %+v\n", key)
if primaryKeys := currentScope.getColumnAsArray(primaryName); len(primaryKeys) > 0 {
condition := fmt.Sprintf("%v IN (?)", currentScope.Quote(relation.ForeignDBName))
currentScope.NewDB().Where(condition, primaryKeys).Find(results, conditions...)
resultValues := reflect.Indirect(reflect.ValueOf(results))
if currentIsSlice {
for i := 0; i < resultValues.Len(); i++ {
result := resultValues.Index(i)
value := getRealValue(result, relation.ForeignFieldName)
objects := currentScope.IndirectValue()
for j := 0; j < objects.Len(); j++ {
object := reflect.Indirect(objects.Index(j))
if equalAsString(getRealValue(object, primaryName), value) {
f := object.FieldByName(field.Name)
f.Set(reflect.Append(f, result))
break
}
}
}
// printutils.PrettyPrint(currentScope.IndirectValue().Interface())
} else {
currentScope.SetColumn(field, resultValues)
}
}
} else {
scope.SetColumn(field, resultValues)
}
}
case "belongs_to":
if primaryKeys := scope.getColumnAsArray(relation.ForeignFieldName); len(primaryKeys) > 0 {
scope.NewDB().Where(primaryKeys).Find(results, conditions...)
resultValues := reflect.Indirect(reflect.ValueOf(results))
for i := 0; i < resultValues.Len(); i++ {
result := resultValues.Index(i)
if isSlice {
value := getRealValue(result, associationPrimaryKey)
objects := scope.IndirectValue()
for j := 0; j < objects.Len(); j++ {
object := reflect.Indirect(objects.Index(j))
if equalAsString(getRealValue(object, relation.ForeignFieldName), value) {
object.FieldByName(field.Name).Set(result)
case "belongs_to":
if primaryKeys := currentScope.getColumnAsArray(relation.ForeignFieldName); len(primaryKeys) > 0 {
currentScope.NewDB().Where(primaryKeys).Find(results, conditions...)
resultValues := reflect.Indirect(reflect.ValueOf(results))
for i := 0; i < resultValues.Len(); i++ {
result := resultValues.Index(i)
if currentIsSlice {
value := getRealValue(result, associationPrimaryKey)
objects := currentScope.IndirectValue()
for j := 0; j < objects.Len(); j++ {
object := reflect.Indirect(objects.Index(j))
if equalAsString(getRealValue(object, relation.ForeignFieldName), value) {
object.FieldByName(field.Name).Set(result)
}
}
} else {
currentScope.SetColumn(field, result)
}
}
} else {
scope.SetColumn(field, result)
}
case "many_to_many":
// currentScope.Err(errors.New("not supported relation"))
fallthrough
default:
currentScope.Err(errors.New("not supported relation"))
}
break
}
case "many_to_many":
scope.Err(errors.New("not supported relation"))
default:
scope.Err(errors.New("not supported relation"))
}
break

if !found {
value := reflect.ValueOf(currentScope.Value)
if value.Kind() == reflect.Slice && value.Type().Elem().Kind() == reflect.Interface {
value = value.Index(0).Elem()
}
scope.Err(fmt.Errorf("can't found field %s in %s", key, value.Type()))
return
}

preloadMap[strings.Join(keys[:i+1], ".")] = true
}

if i < len(keys)-1 {
// TODO: update current scope
currentScope = currentScope.getColumnsAsScope(key)
currentFields = currentScope.Fields()
currentIsSlice = currentScope.IndirectValue().Kind() == reflect.Slice
}
}
}
Expand All @@ -120,19 +170,44 @@ func makeSlice(typ reflect.Type) interface{} {
return slice.Interface()
}

func (scope *Scope) getColumnAsArray(column string) (primaryKeys []interface{}) {
func (scope *Scope) getColumnAsArray(column string) (columns []interface{}) {
values := scope.IndirectValue()
switch values.Kind() {
case reflect.Slice:
primaryKeyMap := map[interface{}]bool{}
for i := 0; i < values.Len(); i++ {
primaryKeyMap[reflect.Indirect(values.Index(i)).FieldByName(column).Interface()] = true
}
for key := range primaryKeyMap {
primaryKeys = append(primaryKeys, key)
columns = append(columns, reflect.Indirect(values.Index(i)).FieldByName(column).Interface())
}
case reflect.Struct:
return []interface{}{values.FieldByName(column).Interface()}
}
return
}

func (scope *Scope) getColumnsAsScope(column string) *Scope {
values := scope.IndirectValue()
// log.Println(values.Type(), column)
switch values.Kind() {
case reflect.Slice:
fieldType, _ := values.Type().Elem().FieldByName(column)
var columns reflect.Value
if fieldType.Type.Kind() == reflect.Slice {
columns = reflect.New(reflect.SliceOf(reflect.PtrTo(fieldType.Type.Elem()))).Elem()
} else {
columns = reflect.New(reflect.SliceOf(reflect.PtrTo(fieldType.Type))).Elem()
}
for i := 0; i < values.Len(); i++ {
column := reflect.Indirect(values.Index(i)).FieldByName(column)
if column.Kind() == reflect.Slice {
for i := 0; i < column.Len(); i++ {
columns = reflect.Append(columns, column.Index(i).Addr())
}
} else {
columns = reflect.Append(columns, column.Addr())
}
}
return scope.New(columns.Interface())
case reflect.Struct:
return scope.New(values.FieldByName(column).Addr().Interface())
}
return nil
}
Loading

0 comments on commit 6d58dc9

Please sign in to comment.