Skip to content

Commit

Permalink
Update training_args.py
Browse files Browse the repository at this point in the history
  • Loading branch information
DesmonDay authored Nov 6, 2024
1 parent 2f0b407 commit e91e592
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion paddlenlp/trainer/training_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -1352,7 +1352,7 @@ def is_segment_parallel_supported():
strategy.hybrid_configs["sharding_configs"].split_param = True
assert self.amp_master_grad, "Currently sharding stage1 v2 only support amp_master_grad"

if "enable_release_grads" in sharding_parallel_config:
if "enable_release_grads" in sharding_parallel_config or self.release_grads:
strategy.hybrid_configs["sharding_configs"].release_gradients = True

if self.pipeline_parallel_degree == 1:
Expand Down

0 comments on commit e91e592

Please sign in to comment.