Skip to content

Commit

Permalink
Support custom preloading SQL, close go-gorm#598, go-gorm#793, go-gor…
Browse files Browse the repository at this point in the history
  • Loading branch information
jinzhu committed Feb 15, 2016
1 parent b054f23 commit 5883c70
Show file tree
Hide file tree
Showing 3 changed files with 52 additions and 13 deletions.
12 changes: 12 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -432,6 +432,18 @@ db.Preload("Orders").Preload("Profile").Preload("Role").Find(&users)
//// SELECT * FROM roles WHERE id IN (4,5,6); // belongs to
```

#### Custom Preloading SQL

You could custom preloading SQL by passing in `func(db *gorm.DB) *gorm.DB` (same type as the one used for [Scopes](#scopes)), for example:

```go
db.Preload("Orders", func(db *gorm.DB) *gorm.DB {
return db.Order("orders.amount DESC")
}).Find(&users)
//// SELECT * FROM users;
//// SELECT * FROM orders WHERE user_id IN (1,2,3,4) order by orders.amount DESC;
```

#### Nested Preloading

```go
Expand Down
49 changes: 37 additions & 12 deletions callback_query_preload.go
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,23 @@ func preloadCallback(scope *Scope) {
}
}

func (scope *Scope) generatePreloadDBWithConditions(conditions []interface{}) (*DB, []interface{}) {
var (
preloadDB = scope.NewDB()
preloadConditions []interface{}
)

for _, condition := range conditions {
if scopes, ok := condition.(func(*DB) *DB); ok {
preloadDB = scopes(preloadDB)
} else {
preloadConditions = append(preloadConditions, condition)
}
}

return preloadDB, preloadConditions
}

// handleHasOnePreload used to preload has one associations
func (scope *Scope) handleHasOnePreload(field *Field, conditions []interface{}) {
relation := field.Relationship
Expand All @@ -83,9 +100,12 @@ func (scope *Scope) handleHasOnePreload(field *Field, conditions []interface{})
return
}

// preload conditions
preloadDB, preloadConditions := scope.generatePreloadDBWithConditions(conditions)

// find relations
results := makeSlice(field.Struct.Type)
scope.Err(scope.NewDB().Where(fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relation.ForeignDBNames), toQueryMarks(primaryKeys)), toQueryValues(primaryKeys)...).Find(results, conditions...).Error)
scope.Err(preloadDB.Where(fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relation.ForeignDBNames), toQueryMarks(primaryKeys)), toQueryValues(primaryKeys)...).Find(results, preloadConditions...).Error)

// assign find results
var (
Expand Down Expand Up @@ -119,9 +139,12 @@ func (scope *Scope) handleHasManyPreload(field *Field, conditions []interface{})
return
}

// preload conditions
preloadDB, preloadConditions := scope.generatePreloadDBWithConditions(conditions)

// find relations
results := makeSlice(field.Struct.Type)
scope.Err(scope.NewDB().Where(fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relation.ForeignDBNames), toQueryMarks(primaryKeys)), toQueryValues(primaryKeys)...).Find(results, conditions...).Error)
scope.Err(preloadDB.Where(fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relation.ForeignDBNames), toQueryMarks(primaryKeys)), toQueryValues(primaryKeys)...).Find(results, preloadConditions...).Error)

// assign find results
var (
Expand Down Expand Up @@ -151,6 +174,9 @@ func (scope *Scope) handleHasManyPreload(field *Field, conditions []interface{})
func (scope *Scope) handleBelongsToPreload(field *Field, conditions []interface{}) {
relation := field.Relationship

// preload conditions
preloadDB, preloadConditions := scope.generatePreloadDBWithConditions(conditions)

// get relations's primary keys
primaryKeys := scope.getColumnAsArray(relation.ForeignFieldNames, scope.Value)
if len(primaryKeys) == 0 {
Expand All @@ -159,7 +185,7 @@ func (scope *Scope) handleBelongsToPreload(field *Field, conditions []interface{

// find relations
results := makeSlice(field.Struct.Type)
scope.Err(scope.NewDB().Where(fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relation.AssociationForeignDBNames), toQueryMarks(primaryKeys)), toQueryValues(primaryKeys)...).Find(results, conditions...).Error)
scope.Err(preloadDB.Where(fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relation.AssociationForeignDBNames), toQueryMarks(primaryKeys)), toQueryValues(primaryKeys)...).Find(results, preloadConditions...).Error)

// assign find results
var (
Expand Down Expand Up @@ -205,21 +231,20 @@ func (scope *Scope) handleManyToManyPreload(field *Field, conditions []interface
sourceKeys = append(sourceKeys, key.DBName)
}

// preload conditions
preloadDB, preloadConditions := scope.generatePreloadDBWithConditions(conditions)

// generate query with join table
newScope := scope.New(reflect.New(fieldType).Interface())
preloadJoinDB := scope.NewDB().Table(newScope.TableName()).Select("*")
preloadJoinDB = joinTableHandler.JoinWith(joinTableHandler, preloadJoinDB, scope.Value)

if primaryField := newScope.PrimaryField(); primaryField != nil {
preloadJoinDB = preloadJoinDB.Order(fmt.Sprintf("%v.%v %v", newScope.QuotedTableName(), newScope.Quote(primaryField.DBName), "ASC"))
}
preloadDB = preloadDB.Table(newScope.TableName()).Select("*")
preloadDB = joinTableHandler.JoinWith(joinTableHandler, preloadDB, scope.Value)

// preload inline conditions
if len(conditions) > 0 {
preloadJoinDB = preloadJoinDB.Where(conditions[0], conditions[1:]...)
if len(preloadConditions) > 0 {
preloadDB = preloadDB.Where(preloadConditions[0], preloadConditions[1:]...)
}

rows, err := preloadJoinDB.Rows()
rows, err := preloadDB.Rows()

if scope.Err(err) != nil {
return
Expand Down
4 changes: 3 additions & 1 deletion preload_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1107,7 +1107,9 @@ func TestNestedManyToManyPreload3(t *testing.T) {
}

var gots []*Level3
if err := DB.Preload("Level2.Level1s").Find(&gots).Error; err != nil {
if err := DB.Preload("Level2.Level1s", func(db *gorm.DB) *gorm.DB {
return db.Order("level1.id ASC")
}).Find(&gots).Error; err != nil {
t.Error(err)
}

Expand Down

0 comments on commit 5883c70

Please sign in to comment.