Skip to content

Commit

Permalink
【AutoParallel】Add strategy with more options (#8114)
Browse files Browse the repository at this point in the history
* add strategy

* polish
  • Loading branch information
heavyrain-lzy authored Mar 14, 2024
1 parent 880d2ea commit c406d90
Showing 1 changed file with 3 additions and 1 deletion.
4 changes: 3 additions & 1 deletion paddlenlp/trainer/auto_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,8 +112,10 @@ def _wrap_for_auto(self, model, train_dataloader):
dist_loader = self._wrap_for_dist_loader(train_dataloader)

if self.args.to_static:
unified_strategy = dist.Strategy()
unified_strategy._from_legacy_strategy(self.args.strategy)
return (
dist.to_static(model, dist_loader, self.criterion, self.optimizer, strategy=self.args.strategy),
dist.to_static(model, dist_loader, self.criterion, self.optimizer, strategy=unified_strategy),
dist_loader,
)
else:
Expand Down

0 comments on commit c406d90

Please sign in to comment.