Non blocking support to torch DL's (#30465)
* Non blocking support

* Check for optimization

* Doc
muellerzr authored Apr 24, 2024
1 parent 5c57463 commit 6ad9c8f
Showing 2 changed files with 27 additions and 0 deletions.
12 changes: 12 additions & 0 deletions src/transformers/trainer.py
@@ -4361,6 +4361,18 @@ def create_accelerator_and_postprocess(self):
                even_batches=accelerator_config.pop("even_batches"),
                use_seedable_sampler=accelerator_config.pop("use_seedable_sampler"),
            )
            non_blocking = accelerator_config.pop("non_blocking")
            if not is_accelerate_available("0.30.0"):
                if non_blocking:
                    raise ImportError(
                        "`non_blocking` is only supported in accelerate v0.30.0 and above. Please upgrade accelerate to use this feature."
                    )
            else:
                if non_blocking and not self.args.dataloader_pin_memory:
                    logger.warning(
                        "`non_blocking` is enabled but `dataloader_pin_memory` is not. For the best performance, it's recommended to enable both."
                    )
                dataloader_config.non_blocking = non_blocking
        # this would have been updated above, no need for it anymore
        accelerator_config.pop("gradient_accumulation_kwargs")

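For reference, a minimal sketch of how a user could opt into this from the Trainer side, assuming the `non_blocking` field added to `AcceleratorConfig` below and the existing `accelerator_config` and `dataloader_pin_memory` arguments of `TrainingArguments`; the model, dataset, and `output_dir` here are placeholders, not part of the commit:

    from transformers import Trainer, TrainingArguments
    from transformers.trainer_pt_utils import AcceleratorConfig

    # Sketch only: enable non-blocking host-to-device copies for prepared DataLoaders.
    # Requires accelerate >= 0.30.0 and pairs best with pinned memory (see the warning above).
    args = TrainingArguments(
        output_dir="out",
        dataloader_pin_memory=True,  # recommended alongside non_blocking
        accelerator_config=AcceleratorConfig(non_blocking=True),  # passing a dict like {"non_blocking": True} is also accepted
    )

    # trainer = Trainer(model=model, args=args, train_dataset=train_dataset)
    # trainer.train()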
15 changes: 15 additions & 0 deletions src/transformers/trainer_pt_utils.py
@@ -1246,6 +1246,10 @@ class AcceleratorConfig:
            The [`accelerate.utils.GradientAccumulationPlugin`] default is `True`.
        sync_each_batch (`bool`): Whether to synchronize the gradients at each data batch.
            The [`accelerate.utils.GradientAccumulationPlugin`] default is `False`.
        non_blocking (`bool`, *optional*, defaults to `False`):
            Whether to use non-blocking CUDA calls to help minimize synchronization during
            distributed training with prepared `DataLoader` inputs being moved to device.
            Best if used with `pin_memory=True` in the `TrainingArguments`.
    """

@@ -1284,6 +1288,17 @@ class AcceleratorConfig:
"multiple different seeds to compare. Should also be ran with [`~utils.set_seed`] for the best results."
},
)

non_blocking: Optional[bool] = field(
default=False,
metadata={
"help": "Whether to use non-blocking CUDA calls to help minimize synchronization during "
"distributed training with prepared `DataLoader` inputs being moved to device. "
"Best if used with `pin_memory=True` in the `TrainingArguments`. Requires accelerate "
"v0.30.0."
},
)

gradient_accumulation_kwargs: Optional[Dict] = field(
default=None,
metadata={
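To illustrate the PyTorch behavior this docstring refers to (not part of this commit): a non-blocking copy can only overlap with other work when the source CPU tensor lives in pinned memory, which is what `dataloader_pin_memory=True` provides; otherwise `.to(..., non_blocking=True)` effectively degrades to a synchronous copy. A minimal sketch:

    import torch

    if torch.cuda.is_available():
        # Pinned (page-locked) memory is what makes non_blocking=True effective;
        # the Trainer's dataloader_pin_memory=True pins each batch the same way.
        cpu_batch = torch.randn(64, 128).pin_memory()
        gpu_batch = cpu_batch.to("cuda", non_blocking=True)  # may return before the copy completes

        # ... other GPU work that does not need gpu_batch can be queued here ...

        # Synchronize before reading gpu_batch on the CPU or from another stream;
        # kernels launched on the same stream are already ordered after the copy.
        torch.cuda.current_stream().synchronize()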
