feat: support tensor parallel using Pytorch 2.0 & Data loader #3173
base: main
Conversation
Signed-off-by: Mehant Kammakomati <mehant.kammakomati2@ibm.com>
Thanks! This looks great to me. We do still need to update this to work with `accelerate config`, however, which happens in commands/config and commands/launch. Would you like to do so?
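For illustration, here is a rough sketch of how a TP option could surface in `accelerate launch`; the `--tp_size` flag name and help text are assumptions for this sketch, not an agreed CLI surface.

```python
# Hypothetical sketch of a TP argument group in commands/launch.py.
# The --tp_size flag name is an assumption, not a confirmed part of this PR.
import argparse

parser = argparse.ArgumentParser(description="accelerate launch (sketch)")
tp_args = parser.add_argument_group("TP Arguments", "Arguments related to tensor parallelism.")
tp_args.add_argument(
    "--tp_size",
    type=int,
    default=1,
    help="Degree of tensor parallelism to use; 1 disables TP.",
)

args = parser.parse_args(["--tp_size", "2"])
print(args.tp_size)  # 2
```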
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
@kmehant if you rebase from
@muellerzr Appreciate your response. I would like to bring the two points below to your notice.
For point (1), I can keep this PR simple and allow only paradigm 1, and address paradigm 2 in another PR. WDYT?
Thanks for this PR, this looks nice. I have a few smaller comments, please take a look.
Also, please ensure that `make quality` passes.
@@ -1457,6 +1463,8 @@ def prepare_model(self, model: torch.nn.Module, device_placement: bool = None, e
            )
        if self.ddp_handler is not None:
            self.ddp_handler.register_comm_hook(model)
        elif self.distributed_type == DistributedType.TP:
            model.apply_tensor_parallel(self.state.torch_tp_plugin.torch_device_mesh["tp"])
`apply_tensor_parallel` will be implemented in huggingface/transformers#34194, but only for select model architectures, right? Should we check for this and, if it is not present, raise an appropriate error?
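A minimal sketch of what such a guard could look like; the helper name is hypothetical, and the only thing relied on is the `apply_tensor_parallel` method expected from huggingface/transformers#34194.

```python
# Hedged sketch of the suggested check; `ensure_tp_support` is an illustrative
# helper, not an existing accelerate function.
def ensure_tp_support(model) -> None:
    """Raise a clear error when the model cannot be tensor-parallelized."""
    if not hasattr(model, "apply_tensor_parallel"):
        raise NotImplementedError(
            f"{model.__class__.__name__} does not implement `apply_tensor_parallel`; "
            "tensor parallelism is only available for select architectures in transformers."
        )
```

Calling this right before the `model.apply_tensor_parallel(...)` line above would turn a confusing `AttributeError` into an actionable message.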
@@ -740,7 +742,29 @@ def _fetch_batches(self, iterator):
        batches, batch = None, None
        # On process 0, we gather the batch to dispatch.
        if self.state.process_index == 0:
            # if a device mesh is provided extract each dimension (tp and dp)
Can we perform this check already during `__init__`? `submesh_tp` and `submesh_dp` could be stored as `self` attributes.
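A hedged sketch of that suggestion: resolve the sub-meshes once at construction time instead of on every `_fetch_batches` call. The mixin name is made up for illustration; the `torch_device_mesh` argument and the "tp"/"dp" dimension names are taken from this PR's discussion and may not match the final code.

```python
from typing import Optional

from torch.distributed.device_mesh import DeviceMesh


class _DispatcherMeshMixin:
    """Sketch: cache tp/dp sub-meshes as attributes instead of re-deriving them per batch."""

    def _init_submeshes(self, torch_device_mesh: Optional[DeviceMesh]) -> None:
        self.submesh_tp = None
        self.submesh_dp = None
        if torch_device_mesh is None:
            return
        dim_names = torch_device_mesh.mesh_dim_names or ()
        if "tp" in dim_names:
            self.submesh_tp = torch_device_mesh["tp"]
        if "dp" in dim_names:
            self.submesh_dp = torch_device_mesh["dp"]
```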
@@ -1009,6 +1040,8 @@ def prepare_data_loader(
    "If set to true, the dataloader prepared by the Accelerator will be backed by "
    "[torchdata.StatefulDataLoader](https://github.com/pytorch/data/tree/main/torchdata/stateful_dataloader).
    This requires `torchdata` version 0.8.0 or higher that supports StatefulDataLoader to be installed."
    torch_device_mesh (``, *optional*, defaults to `None`):
The type is missing.
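One possible fix, assuming the intended type is `torch.distributed.DeviceMesh` (a guess based on the plugin below, not confirmed in this thread):

- torch_device_mesh (``, *optional*, defaults to `None`):
+ torch_device_mesh (`torch.distributed.DeviceMesh`, *optional*, defaults to `None`):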
@dataclass
class TorchTensorParallelPlugin:
    """
    This plugin is used to enable tensor parallelism using PyTorch 2.0.
I think a reference to further information on TP should be added here. Also:
- This plugin is used to enable tensor parallelism using PyTorch 2.0.
+ This plugin is used to enable tensor parallelism using PyTorch >= 2.0.
)

def __post_init__(self):
    from torch.distributed.device_mesh import init_device_mesh
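Putting the two review points together, here is a hedged sketch of what the plugin could look like with a TP reference in the docstring and a `__post_init__` that builds the mesh. The `tp_size` field, the single "tp" mesh dimension, and the class name suffix are assumptions; the link is the upstream PyTorch TP documentation, not something referenced by this PR.

```python
from dataclasses import dataclass, field

import torch


@dataclass
class TorchTensorParallelPluginSketch:
    """
    This plugin is used to enable tensor parallelism using PyTorch >= 2.0.

    For background on PyTorch tensor parallelism, see
    https://pytorch.org/docs/stable/distributed.tensor.parallel.html
    """

    tp_size: int = 1  # assumed field name: number of ranks each layer is sharded over
    torch_device_mesh: object = field(init=False, default=None)

    def __post_init__(self):
        from torch.distributed.device_mesh import init_device_mesh

        # Build a 1-D "tp" mesh on the available accelerator; must run under a
        # distributed launcher (e.g. torchrun). The real plugin may also carry
        # a data-parallel dimension.
        device_type = "cuda" if torch.cuda.is_available() else "cpu"
        self.torch_device_mesh = init_device_mesh(
            device_type, (self.tp_size,), mesh_dim_names=("tp",)
        )
```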
Should we perform a check on the minimum PyTorch and transformers versions? Not sure if here is the best place or somewhere else, Zach?
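If such a check lands in `__post_init__`, a minimal sketch could look like the following; the minimum versions shown are placeholders, not values decided in this thread.

```python
# Hedged version-guard sketch; 2.3.0 / 4.46.0 are assumed minimums for illustration.
import importlib.metadata

import torch
from packaging import version

MIN_TORCH = "2.3.0"
MIN_TRANSFORMERS = "4.46.0"


def check_tp_requirements():
    if version.parse(torch.__version__) < version.parse(MIN_TORCH):
        raise ImportError(
            f"Tensor parallelism needs torch >= {MIN_TORCH}, found {torch.__version__}."
        )
    installed = importlib.metadata.version("transformers")
    if version.parse(installed) < version.parse(MIN_TRANSFORMERS):
        raise ImportError(
            f"Tensor parallelism needs transformers >= {MIN_TRANSFORMERS}, found {installed}."
        )
```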
What does this PR do?
Adds `TorchTensorParallelPlugin` to support TP with PyTorch 2.0. This work should be seen along with the PR "feat: add support for tensor parallel using Pytorch 2.0" (huggingface/transformers#34194). Please review in conjunction with huggingface/transformers#34194.
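For reviewers, a hedged end-to-end sketch of how the plugin might be used. The `tp_size` constructor argument and the `torch_tp_plugin` keyword on `Accelerator` are assumptions based on this thread, and the model/optimizer setup is only illustrative; a real run would use a transformers model that implements `apply_tensor_parallel`.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

from accelerate import Accelerator
from accelerate.utils import TorchTensorParallelPlugin  # added by this PR

# Assumed constructor argument: shard each supported layer across 4 ranks.
tp_plugin = TorchTensorParallelPlugin(tp_size=4)
accelerator = Accelerator(torch_tp_plugin=tp_plugin)  # assumed keyword name

# Placeholder model/data for the sketch only.
model = torch.nn.Linear(128, 128)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
dataloader = DataLoader(TensorDataset(torch.randn(64, 128)), batch_size=8)

model, optimizer, dataloader = accelerator.prepare(model, optimizer, dataloader)
```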
Results
We see significant improvements in both memory and throughput compared against single-GPU training and FSDP, across different settings (checkpointing on/off) and context lengths.
Benchmarking was done on two models.
The tables below show the max CUDA memory and throughput for various configurations, demonstrating the potential of the TP support contributed in this PR. There are gains in both memory and throughput.
Fixes huggingface/transformers#32470
Before submitting
- Did you read the contributor guideline, Pull Request section?
- Was this discussed/approved via a GitHub issue or the forum? Please add a link to it if that's the case.
- Did you make sure to update the documentation with your changes? Here are the documentation guidelines, and here are tips on formatting docstrings.
Who can review?
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.
I have cycles to bring more improvements on top of this PR to land PyTorch TP support in HF. Looking forward to it. Thank you!