
feat: support tensor parallel using Pytorch 2.0 & Data loader #3173

Open · wants to merge 1 commit into base: main

Conversation

@kmehant kmehant commented Oct 16, 2024

What does this PR do?

  1. Implements `TorchTensorParallelPlugin` to support TP with PyTorch 2.0. This work should be seen together with the PR feat: add support for tensor parallel using Pytorch 2.0 transformers#34194 (a rough usage sketch follows below).
  2. Modifies the dataloader to support passing the same samples across TP ranks.

Please review in conjunction with huggingface/transformers#34194
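
For context, a rough, hedged sketch of how the plugin is intended to be wired into a training script. The `tp_size` argument, the `torch_tp_plugin` keyword on `Accelerator`, and the import path are assumptions for illustration; the PR diff is the authoritative interface.

```python
# Hedged usage sketch -- `tp_size`, `torch_tp_plugin`, and the import path are
# assumed names for illustration, not the finalized API of this PR.
import torch
from transformers import AutoModelForCausalLM
from accelerate import Accelerator
from accelerate.utils import TorchTensorParallelPlugin

# One tensor-parallel group spanning 4 GPUs; the plugin builds a torch
# DeviceMesh internally and exposes it as `torch_device_mesh`.
tp_plugin = TorchTensorParallelPlugin(tp_size=4)
accelerator = Accelerator(torch_tp_plugin=tp_plugin)

model = AutoModelForCausalLM.from_pretrained("codellama/CodeLlama-7b-hf")
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5)

# prepare() applies the model's TP plan across the mesh and (with the dataloader
# changes in this PR) ensures all TP ranks receive identical batches.
model, optimizer = accelerator.prepare(model, optimizer)
```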

Results

We see significant improvements in both memory and throughput compared against single-GPU training and FSDP, across different settings (gradient checkpointing on/off) and context lengths.

Benchmarks were run on two models:

  1. ibm-granite/granite-8b-code-base-128k
  2. codellama/CodeLlama-7b-hf

The tables below show the max CUDA memory and throughput for various configurations, demonstrating the potential of the TP support contributed in this PR. There are gains in both memory and throughput.

| Model | Method | # of GPUs | Context Length | Batch Size | Grad Checkpointing | CUDA Max Mem (GiB) | Tokens/Sec/GPU |
|---|---|---|---|---|---|---|---|
| ibm-granite/granite-8b-code-base-128k | Single GPU (non-distributed) | 1 | 8192 | 1 | FALSE | OOM | NA |
| ibm-granite/granite-8b-code-base-128k | FSDP | 4 | 8192 | 1 | FALSE | OOM | NA |
| ibm-granite/granite-8b-code-base-128k | TP (This PR) | 4 | 8192 | 1 | FALSE | 52.4 | 7675.4 |

| Model | Method | # of GPUs | Context Length | Batch Size | Grad Checkpointing | CUDA Max Mem (GiB) | Tokens/Sec/GPU |
|---|---|---|---|---|---|---|---|
| ibm-granite/granite-8b-code-base-128k | Single GPU (non-distributed) | 1 | 8192 | 1 | TRUE | OOM | NA |
| ibm-granite/granite-8b-code-base-128k | FSDP | 4 | 8192 | 1 | TRUE | 29.975586 | 2256.896 |
| ibm-granite/granite-8b-code-base-128k | TP (This PR) | 4 | 8192 | 1 | TRUE | 26.5 | 5935.5 |

| Model | Method | # of GPUs | Context Length | Batch Size | Grad Checkpointing | CUDA Max Mem (GiB) | Tokens/Sec/GPU |
|---|---|---|---|---|---|---|---|
| ibm-granite/granite-8b-code-base-128k | Single GPU (non-distributed) | 1 | 16384 | 1 | FALSE | OOM | NA |
| ibm-granite/granite-8b-code-base-128k | FSDP | 4 | 16384 | 1 | FALSE | OOM | NA |
| ibm-granite/granite-8b-code-base-128k | TP (This PR) | 4 | 16384 | 1 | FALSE | OOM | NA |

| Model | Method | # of GPUs | Context Length | Batch Size | Grad Checkpointing | CUDA Max Mem (GiB) | Tokens/Sec/GPU |
|---|---|---|---|---|---|---|---|
| ibm-granite/granite-8b-code-base-128k | Single GPU (non-distributed) | 1 | 16384 | 1 | TRUE | OOM | NA |
| ibm-granite/granite-8b-code-base-128k | FSDP | 4 | 16384 | 1 | TRUE | 36.8 | 2084.864 |
| ibm-granite/granite-8b-code-base-128k | TP (This PR) | 4 | 16384 | 1 | TRUE | 33.5 | 5692.5 |

| Model | Method | # of GPUs | Context Length | Batch Size | Grad Checkpointing | CUDA Max Mem (GiB) | Tokens/Sec/GPU |
|---|---|---|---|---|---|---|---|
| codellama/CodeLlama-7b-hf | Single GPU (non-distributed) | 1 | 8192 | 1 | FALSE | OOM | NA |
| codellama/CodeLlama-7b-hf | FSDP | 4 | 8192 | 1 | FALSE | 70.7 | 3560 |
| codellama/CodeLlama-7b-hf | TP (This PR) | 4 | 8192 | 1 | FALSE | 42.8 | 9216 |

| Model | Method | # of GPUs | Context Length | Batch Size | Grad Checkpointing | CUDA Max Mem (GiB) | Tokens/Sec/GPU |
|---|---|---|---|---|---|---|---|
| codellama/CodeLlama-7b-hf | Single GPU (non-distributed) | 1 | 8192 | 1 | TRUE | 75.3 | 2849 |
| codellama/CodeLlama-7b-hf | FSDP | 4 | 8192 | 1 | TRUE | 26.4 | 5957 |
| codellama/CodeLlama-7b-hf | TP (This PR) | 4 | 8192 | 1 | TRUE | 21.4 | 7125 |

| Model | Method | # of GPUs | Context Length | Batch Size | Grad Checkpointing | CUDA Max Mem (GiB) | Tokens/Sec/GPU |
|---|---|---|---|---|---|---|---|
| codellama/CodeLlama-7b-hf | Single GPU (non-distributed) | 1 | 16384 | 1 | FALSE | OOM | NA |
| codellama/CodeLlama-7b-hf | FSDP | 4 | 16384 | 1 | FALSE | OOM | NA |
| codellama/CodeLlama-7b-hf | TP (This PR) | 4 | 16384 | 1 | FALSE | OOM | NA |

| Model | Method | # of GPUs | Context Length | Batch Size | Grad Checkpointing | CUDA Max Mem (GiB) | Tokens/Sec/GPU |
|---|---|---|---|---|---|---|---|
| codellama/CodeLlama-7b-hf | Single GPU (non-distributed) | 1 | 16384 | 1 | TRUE | 75.3 | 2599 |
| codellama/CodeLlama-7b-hf | FSDP | 4 | 16384 | 1 | TRUE | 30.1 | 2433 |
| codellama/CodeLlama-7b-hf | TP (This PR) | 4 | 16384 | 1 | TRUE | 26.6 | 6873 |

Fixes: huggingface/transformers#32470

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a GitHub issue or the forum? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the
    documentation guidelines, and
    here are tips on formatting docstrings.
  • Did you write any new necessary tests?

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

I have cycles to bring further improvements on top of this PR toward PyTorch TP support in HF. Looking forward to your feedback. Thank you!

Signed-off-by: Mehant Kammakomati <mehant.kammakomati2@ibm.com>
@kmehant changed the title from "feat: support tensor parallel using Pytorch 2.0" to "feat: support tensor parallel using Pytorch 2.0 & Data loader" on Oct 24, 2024
@muellerzr (Collaborator) left a comment


Thanks! This looks great to me. We do still need to update this to work with `accelerate config`, however, which happens in `commands/config` and `commands/launch`. Would you like to do so?

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@muellerzr (Collaborator) commented:
@kmehant if you rebase from main this should fix the failures (tl;dr: we dropped Python 3.8 since it reached EOL)

@kmehant (Author) commented Oct 29, 2024

@muellerzr I appreciate your response. I would like to bring the following two points to your notice.

  1. The dataloader in this PR is written for the paradigm (call it paradigm 1) where the master process fetches the data and distributes it to all worker processes. The more general paradigm (call it paradigm 2), where every process fetches its own data sample (which, in the TP case, has to be the same batch across processes), is not covered in this PR.
  2. This PR has a soft dependency on applying a TP plan to the model, since it really consists of two parts: the TP workflow through the accelerate plugin, plus the dataloader.
    1. The first part applies tensor parallelism to the model as shown here - https://github.com/huggingface/accelerate/pull/3173/files#diff-2d7515874eaecac2687c7fc1a9c720be53f802bf14b4c3dcebe14ad443d075dcR1467 - creating a soft dependency on feat: add support for tensor parallel using Pytorch 2.0 transformers#34194 (part of this would be superseded by Simplify Tensor Parallel implementation with PyTorch TP transformers#34184, which carries a different interface for applying the TP plan to the model).
    2. The second part is the dataloader changes.

For point (1), I can keep this PR simple, support only paradigm 1, and address paradigm 2 in a follow-up PR.
For point (2), I can remove the TP application part from this PR, keeping it simple and independent; the removed part can be added in a separate PR once point (2)(i) is completed. A standalone sketch of what paradigm 1 means for the TP case is shown after this comment.

WDYT?
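
To make paradigm 1 concrete, here is a minimal standalone sketch (independent of the accelerate internals, with an assumed 2 x 4 mesh layout): one rank per TP group reads from the dataloader and broadcasts the batch so that every rank in the group trains on identical samples.

```python
# Standalone illustration of "paradigm 1" for the TP case -- not the accelerate
# implementation. One rank per TP group fetches a batch and broadcasts it, so
# all ranks of the group see the same samples.
import torch
import torch.distributed as dist
from torch.distributed.device_mesh import init_device_mesh


def fetch_same_batch_per_tp_group(data_iter, tp_mesh):
    """Rank 0 of each TP group reads from the dataloader; the other ranks get a copy."""
    tp_group = tp_mesh.get_group()
    src = dist.get_global_rank(tp_group, 0)  # global rank of this group's source
    payload = [next(data_iter)] if dist.get_rank() == src else [None]
    dist.broadcast_object_list(payload, src=src, group=tp_group)
    return payload[0]


if __name__ == "__main__":
    # Run with torchrun on 8 GPUs; the layout is an example: 2 DP replicas x 4-way TP.
    dist.init_process_group("nccl")
    torch.cuda.set_device(dist.get_rank() % torch.cuda.device_count())
    mesh = init_device_mesh("cuda", (2, 4), mesh_dim_names=("dp", "tp"))
    dummy_loader = iter([{"input_ids": torch.arange(8).unsqueeze(0)} for _ in range(10)])
    batch = fetch_same_batch_per_tp_group(dummy_loader, mesh["tp"])
```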

@BenjaminBossan (Member) left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for this PR, this looks nice. I have a few smaller comments, please take a look.

Also, please ensure that `make quality` passes.

```diff
@@ -1457,6 +1463,8 @@ def prepare_model(self, model: torch.nn.Module, device_placement: bool = None, e
             )
             if self.ddp_handler is not None:
                 self.ddp_handler.register_comm_hook(model)
+        elif self.distributed_type == DistributedType.TP:
+            model.apply_tensor_parallel(self.state.torch_tp_plugin.torch_device_mesh["tp"])
```

`apply_tensor_parallel` will be implemented in huggingface/transformers#34194 but only for select model architectures, right? Should we check this and, if it is not present, raise an appropriate error?
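
For what it's worth, the guard could look roughly like this (a hypothetical helper, not code from the PR; only `apply_tensor_parallel` and the mesh argument follow the diff above):

```python
# Hypothetical sketch of the suggested check -- not part of the PR.
import torch.nn as nn
from torch.distributed.device_mesh import DeviceMesh


def apply_tp_or_raise(model: nn.Module, tp_mesh: DeviceMesh) -> nn.Module:
    """Apply the model's TP plan, failing early for unsupported architectures."""
    if not hasattr(model, "apply_tensor_parallel"):
        raise NotImplementedError(
            f"{type(model).__name__} does not implement `apply_tensor_parallel`; "
            "tensor parallelism is only available for select architectures "
            "(see huggingface/transformers#34194)."
        )
    model.apply_tensor_parallel(tp_mesh)
    return model
```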

```diff
@@ -740,7 +742,29 @@ def _fetch_batches(self, iterator):
         batches, batch = None, None
         # On process 0, we gather the batch to dispatch.
         if self.state.process_index == 0:
+            # if a device mesh is provided extract each dimension (tp and dp)
```

Can we perform this check already during `__init__`? `submesh_tp` and `submesh_dp` could be stored as `self` attributes.
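
Something along these lines, perhaps (a hypothetical sketch of the suggestion; the class and attribute names are illustrative, not the PR's code):

```python
# Hypothetical sketch: resolve the TP/DP sub-meshes once at construction time
# instead of inside every _fetch_batches call.
from typing import Optional

from torch.distributed.device_mesh import DeviceMesh


class DataLoaderDispatcherSketch:
    def __init__(self, torch_device_mesh: Optional[DeviceMesh] = None):
        self.torch_device_mesh = torch_device_mesh
        self.submesh_tp = None
        self.submesh_dp = None
        if torch_device_mesh is not None:
            dim_names = torch_device_mesh.mesh_dim_names or ()
            if "tp" in dim_names:
                self.submesh_tp = torch_device_mesh["tp"]
            if "dp" in dim_names:
                self.submesh_dp = torch_device_mesh["dp"]

    def _fetch_batches(self, iterator):
        # would reuse self.submesh_tp / self.submesh_dp here instead of re-deriving them
        raise NotImplementedError
```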

```diff
@@ -1009,6 +1040,8 @@ def prepare_data_loader(
             "If set to true, the dataloader prepared by the Accelerator will be backed by "
             "[torchdata.StatefulDataLoader](https://github.com/pytorch/data/tree/main/torchdata/stateful_dataloader).
             This requires `torchdata` version 0.8.0 or higher that supports StatefulDataLoader to be installed."
+        torch_device_mesh (``, *optional*, defaults to `None`):
```

The type is missing.

```python
@dataclass
class TorchTensorParallelPlugin:
    """
    This plugin is used to enable tensor parallelism using PyTorch 2.0.
```

I think a reference to further information on TP should be added here. Also:

Suggested change:

```diff
-    This plugin is used to enable tensor parallelism using PyTorch 2.0.
+    This plugin is used to enable tensor parallelism using PyTorch >= 2.0.
```

```python
    )

    def __post_init__(self):
        from torch.distributed.device_mesh import init_device_mesh
```

Should we perform a check on the minimum PyTorch and transformers versions? Not sure if here is the best place or somewhere else, Zach?
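
A version guard in `__post_init__` could look roughly like the sketch below; the minimum version shown is a placeholder (not something decided in this PR), and the field name mirrors the hedged usage sketch earlier in this thread.

```python
# Sketch of a possible version check (the "2.3.0" minimum is a placeholder).
from dataclasses import dataclass

import torch
from packaging import version


@dataclass
class TorchTensorParallelPluginSketch:
    tp_size: int = 1

    def __post_init__(self):
        if version.parse(torch.__version__) < version.parse("2.3.0"):
            raise ImportError(
                "Tensor parallelism via device meshes requires PyTorch >= 2.3.0, "
                f"but torch {torch.__version__} is installed."
            )
        from torch.distributed.device_mesh import init_device_mesh

        # Build a 1-D TP mesh; prepare_model consumes it via torch_device_mesh["tp"].
        self.torch_device_mesh = init_device_mesh("cuda", (self.tp_size,), mesh_dim_names=("tp",))
```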
