diff --git a/README.md b/README.md index 5751964f4..0f195561a 100644 --- a/README.md +++ b/README.md @@ -432,6 +432,18 @@ db.Preload("Orders").Preload("Profile").Preload("Role").Find(&users) //// SELECT * FROM roles WHERE id IN (4,5,6); // belongs to ``` +#### Custom Preloading SQL + +You could custom preloading SQL by passing in `func(db *gorm.DB) *gorm.DB` (same type as the one used for [Scopes](#scopes)), for example: + +```go +db.Preload("Orders", func(db *gorm.DB) *gorm.DB { + return db.Order("orders.amount DESC") +}).Find(&users) +//// SELECT * FROM users; +//// SELECT * FROM orders WHERE user_id IN (1,2,3,4) order by orders.amount DESC; +``` + #### Nested Preloading ```go diff --git a/callback_query_preload.go b/callback_query_preload.go index 97591915d..e57caad0b 100644 --- a/callback_query_preload.go +++ b/callback_query_preload.go @@ -73,6 +73,23 @@ func preloadCallback(scope *Scope) { } } +func (scope *Scope) generatePreloadDBWithConditions(conditions []interface{}) (*DB, []interface{}) { + var ( + preloadDB = scope.NewDB() + preloadConditions []interface{} + ) + + for _, condition := range conditions { + if scopes, ok := condition.(func(*DB) *DB); ok { + preloadDB = scopes(preloadDB) + } else { + preloadConditions = append(preloadConditions, condition) + } + } + + return preloadDB, preloadConditions +} + // handleHasOnePreload used to preload has one associations func (scope *Scope) handleHasOnePreload(field *Field, conditions []interface{}) { relation := field.Relationship @@ -83,9 +100,12 @@ func (scope *Scope) handleHasOnePreload(field *Field, conditions []interface{}) return } + // preload conditions + preloadDB, preloadConditions := scope.generatePreloadDBWithConditions(conditions) + // find relations results := makeSlice(field.Struct.Type) - scope.Err(scope.NewDB().Where(fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relation.ForeignDBNames), toQueryMarks(primaryKeys)), toQueryValues(primaryKeys)...).Find(results, conditions...).Error) + scope.Err(preloadDB.Where(fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relation.ForeignDBNames), toQueryMarks(primaryKeys)), toQueryValues(primaryKeys)...).Find(results, preloadConditions...).Error) // assign find results var ( @@ -119,9 +139,12 @@ func (scope *Scope) handleHasManyPreload(field *Field, conditions []interface{}) return } + // preload conditions + preloadDB, preloadConditions := scope.generatePreloadDBWithConditions(conditions) + // find relations results := makeSlice(field.Struct.Type) - scope.Err(scope.NewDB().Where(fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relation.ForeignDBNames), toQueryMarks(primaryKeys)), toQueryValues(primaryKeys)...).Find(results, conditions...).Error) + scope.Err(preloadDB.Where(fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relation.ForeignDBNames), toQueryMarks(primaryKeys)), toQueryValues(primaryKeys)...).Find(results, preloadConditions...).Error) // assign find results var ( @@ -151,6 +174,9 @@ func (scope *Scope) handleHasManyPreload(field *Field, conditions []interface{}) func (scope *Scope) handleBelongsToPreload(field *Field, conditions []interface{}) { relation := field.Relationship + // preload conditions + preloadDB, preloadConditions := scope.generatePreloadDBWithConditions(conditions) + // get relations's primary keys primaryKeys := scope.getColumnAsArray(relation.ForeignFieldNames, scope.Value) if len(primaryKeys) == 0 { @@ -159,7 +185,7 @@ func (scope *Scope) handleBelongsToPreload(field *Field, conditions []interface{ // find relations results := makeSlice(field.Struct.Type) - scope.Err(scope.NewDB().Where(fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relation.AssociationForeignDBNames), toQueryMarks(primaryKeys)), toQueryValues(primaryKeys)...).Find(results, conditions...).Error) + scope.Err(preloadDB.Where(fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relation.AssociationForeignDBNames), toQueryMarks(primaryKeys)), toQueryValues(primaryKeys)...).Find(results, preloadConditions...).Error) // assign find results var ( @@ -205,21 +231,20 @@ func (scope *Scope) handleManyToManyPreload(field *Field, conditions []interface sourceKeys = append(sourceKeys, key.DBName) } + // preload conditions + preloadDB, preloadConditions := scope.generatePreloadDBWithConditions(conditions) + // generate query with join table newScope := scope.New(reflect.New(fieldType).Interface()) - preloadJoinDB := scope.NewDB().Table(newScope.TableName()).Select("*") - preloadJoinDB = joinTableHandler.JoinWith(joinTableHandler, preloadJoinDB, scope.Value) - - if primaryField := newScope.PrimaryField(); primaryField != nil { - preloadJoinDB = preloadJoinDB.Order(fmt.Sprintf("%v.%v %v", newScope.QuotedTableName(), newScope.Quote(primaryField.DBName), "ASC")) - } + preloadDB = preloadDB.Table(newScope.TableName()).Select("*") + preloadDB = joinTableHandler.JoinWith(joinTableHandler, preloadDB, scope.Value) // preload inline conditions - if len(conditions) > 0 { - preloadJoinDB = preloadJoinDB.Where(conditions[0], conditions[1:]...) + if len(preloadConditions) > 0 { + preloadDB = preloadDB.Where(preloadConditions[0], preloadConditions[1:]...) } - rows, err := preloadJoinDB.Rows() + rows, err := preloadDB.Rows() if scope.Err(err) != nil { return diff --git a/preload_test.go b/preload_test.go index 3ba0cf925..8f21bc97b 100644 --- a/preload_test.go +++ b/preload_test.go @@ -1107,7 +1107,9 @@ func TestNestedManyToManyPreload3(t *testing.T) { } var gots []*Level3 - if err := DB.Preload("Level2.Level1s").Find(&gots).Error; err != nil { + if err := DB.Preload("Level2.Level1s", func(db *gorm.DB) *gorm.DB { + return db.Order("level1.id ASC") + }).Find(&gots).Error; err != nil { t.Error(err) }